Coverage for /pythoncovmergedfiles/medio/medio/src/aiohttp/aiohttp/web_runner.py: 39%

233 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-06-07 06:52 +0000

1import asyncio 

2import signal 

3import socket 

4from abc import ABC, abstractmethod 

5from contextlib import suppress 

6from typing import Any, List, Optional, Set, Type 

7 

8from yarl import URL 

9 

10from .abc import AbstractAccessLogger, AbstractStreamWriter 

11from .http_parser import RawRequestMessage 

12from .streams import StreamReader 

13from .typedefs import PathLike 

14from .web_app import Application 

15from .web_log import AccessLogger 

16from .web_protocol import RequestHandler 

17from .web_request import Request 

18from .web_server import Server 

19 

20try: 

21 from ssl import SSLContext 

22except ImportError: 

23 SSLContext = object # type: ignore[misc,assignment] 

24 

25 

# Public API of this module: the listening "sites", the "runners" that own
# the low-level server and its sites, and the exception raised by the
# optional SIGINT/SIGTERM handlers.
__all__ = (
    # Sites
    "BaseSite",
    "TCPSite",
    "UnixSite",
    "NamedPipeSite",
    "SockSite",
    # Runners
    "BaseRunner",
    "AppRunner",
    "ServerRunner",
    # Exceptions
    "GracefulExit",
)

37 

38 

class GracefulExit(SystemExit):
    """Exception used to stop serving when SIGINT/SIGTERM is received.

    It derives from :class:`SystemExit`, so ``except Exception`` handlers do
    not intercept it. The class-level ``code`` shadows ``SystemExit.code``,
    so every instance reports exit status 1.
    """

    code = 1

41 

42 

def _raise_graceful_exit() -> None:
    """Signal-handler callback that raises :class:`GracefulExit`.

    ``BaseRunner.setup()`` registers it for SIGINT and SIGTERM when the
    runner was created with ``handle_signals=True``.
    """
    raise GracefulExit

45 

46 

class BaseSite(ABC):
    """Abstract listening endpoint attached to a :class:`BaseRunner`.

    Concrete subclasses create the underlying asyncio server in
    :meth:`start` (after registering with the runner); :meth:`stop` closes
    it, drains outstanding work and unregisters the site.
    """

    __slots__ = ("_runner", "_shutdown_timeout", "_ssl_context", "_backlog", "_server")

    def __init__(
        self,
        runner: "BaseRunner",
        *,
        shutdown_timeout: float = 60.0,
        ssl_context: Optional[SSLContext] = None,
        backlog: int = 128,
    ) -> None:
        """Bind the site to *runner*.

        Raises:
            RuntimeError: if ``runner.server`` is still ``None``, i.e.
                ``runner.setup()`` has not been awaited yet.
        """
        if runner.server is None:
            raise RuntimeError("Call runner.setup() before making a site")
        self._runner = runner
        self._ssl_context = ssl_context
        self._backlog = backlog
        self._shutdown_timeout = shutdown_timeout
        # Populated by the concrete start() implementation.
        self._server: Optional[asyncio.AbstractServer] = None

    @property
    @abstractmethod
    def name(self) -> str:
        """Human-readable description of where this site listens."""

    @abstractmethod
    async def start(self) -> None:
        """Register with the runner; subclasses then create the server."""
        self._runner._reg_site(self)

    async def stop(self) -> None:
        """Stop listening, drain pending work, then unregister the site.

        A site that was never started is simply unregistered. Otherwise:
        the listener is closed, pending tasks get up to ``shutdown_timeout``
        seconds to finish (a timeout is not an error), the runner's
        ``shutdown()`` hook runs, and the runner's server is shut down with
        the same timeout before the site is unregistered.
        """
        runner = self._runner
        runner._check_site(self)

        listener = self._server
        if listener is None:
            # Never started: nothing to close.
            runner._unreg_site(self)
            return

        listener.close()
        # Named-pipe servers have no wait_closed(), hence the probe.
        # NOTE(review): on Python 3.12.1+ wait_closed() also waits for active
        # connections to finish, which may delay this step — confirm.
        if hasattr(listener, "wait_closed"):
            await listener.wait_closed()

        with suppress(asyncio.TimeoutError):
            await asyncio.wait_for(
                self._wait(asyncio.current_task()), timeout=self._shutdown_timeout
            )

        await runner.shutdown()
        web_server = runner.server
        assert web_server
        await web_server.shutdown(self._shutdown_timeout)
        runner._unreg_site(self)

    async def _wait(self, parent_task: Optional["asyncio.Task[object]"]) -> None:
        """Wait until every task not known at startup has finished.

        Tasks in ``runner.starting_tasks``, the task running this coroutine
        and *parent_task* are ignored. The task set is re-sampled after each
        round because finishing tasks may have spawned new ones.
        """
        ignored = self._runner.starting_tasks | {asyncio.current_task(), parent_task}
        while True:
            pending = asyncio.all_tasks() - ignored
            if not pending:
                return
            await asyncio.wait(pending)

103 

104 

class TCPSite(BaseSite):
    """Site listening on a TCP host/port through ``loop.create_server()``."""

    __slots__ = ("_host", "_port", "_reuse_address", "_reuse_port")

    def __init__(
        self,
        runner: "BaseRunner",
        host: Optional[str] = None,
        port: Optional[int] = None,
        *,
        shutdown_timeout: float = 60.0,
        ssl_context: Optional[SSLContext] = None,
        backlog: int = 128,
        reuse_address: Optional[bool] = None,
        reuse_port: Optional[bool] = None,
    ) -> None:
        super().__init__(
            runner,
            backlog=backlog,
            ssl_context=ssl_context,
            shutdown_timeout=shutdown_timeout,
        )
        self._host = host
        # Without an explicit port, pick 8443 for TLS and 8080 for plain HTTP.
        default_port = 8443 if self._ssl_context else 8080
        self._port = default_port if port is None else port
        self._reuse_address = reuse_address
        self._reuse_port = reuse_port

    @property
    def name(self) -> str:
        """URL of the listener; a ``None`` host is shown as ``0.0.0.0``."""
        shown_host = self._host if self._host is not None else "0.0.0.0"
        scheme = "https" if self._ssl_context else "http"
        return str(URL.build(scheme=scheme, host=shown_host, port=self._port))

    async def start(self) -> None:
        """Register with the runner and start the TCP server."""
        await super().start()
        web_server = self._runner.server
        assert web_server is not None
        loop = asyncio.get_event_loop()
        self._server = await loop.create_server(
            web_server,
            self._host,
            self._port,
            ssl=self._ssl_context,
            backlog=self._backlog,
            reuse_address=self._reuse_address,
            reuse_port=self._reuse_port,
        )

153 

154 

class UnixSite(BaseSite):
    """Site listening on a Unix domain socket via ``loop.create_unix_server()``."""

    __slots__ = ("_path",)

    def __init__(
        self,
        runner: "BaseRunner",
        path: PathLike,
        *,
        shutdown_timeout: float = 60.0,
        ssl_context: Optional[SSLContext] = None,
        backlog: int = 128,
    ) -> None:
        super().__init__(
            runner,
            backlog=backlog,
            ssl_context=ssl_context,
            shutdown_timeout=shutdown_timeout,
        )
        self._path = path

    @property
    def name(self) -> str:
        """``<scheme>://unix:<path>:``, with the scheme reflecting TLS use."""
        scheme = "https" if self._ssl_context else "http"
        return f"{scheme}://unix:{self._path}:"

    async def start(self) -> None:
        """Register with the runner and start the Unix-socket server."""
        await super().start()
        web_server = self._runner.server
        assert web_server is not None
        loop = asyncio.get_event_loop()
        self._server = await loop.create_unix_server(
            web_server,
            self._path,
            ssl=self._ssl_context,
            backlog=self._backlog,
        )

191 

192 

class NamedPipeSite(BaseSite):
    """Site serving a Windows named pipe; requires a proactor event loop."""

    __slots__ = ("_path",)

    def __init__(
        self, runner: "BaseRunner", path: str, *, shutdown_timeout: float = 60.0
    ) -> None:
        """Create a site for the named pipe *path*.

        Raises:
            RuntimeError: if the current event loop is not an
                ``asyncio.ProactorEventLoop``. That class exists only on
                Windows, so this is always raised on other platforms.
        """
        loop = asyncio.get_event_loop()
        # ``asyncio.ProactorEventLoop`` is only defined on Windows; accessing
        # it directly would raise AttributeError elsewhere instead of the
        # intended RuntimeError, so look it up defensively.
        proactor_cls = getattr(asyncio, "ProactorEventLoop", None)
        if proactor_cls is None or not isinstance(loop, proactor_cls):
            raise RuntimeError(
                "Named Pipes only available in proactor loop under windows"
            )
        super().__init__(runner, shutdown_timeout=shutdown_timeout)
        self._path = path

    @property
    def name(self) -> str:
        """The pipe path this site serves."""
        return self._path

    async def start(self) -> None:
        """Register with the runner and start serving the named pipe."""
        await super().start()
        loop = asyncio.get_event_loop()
        server = self._runner.server
        assert server is not None
        _server = await loop.start_serving_pipe(  # type: ignore[attr-defined]
            server, self._path
        )
        # The result is indexed; its first element is kept as the server handle.
        self._server = _server[0]

222 

223 

class SockSite(BaseSite):
    """Site serving on an already-created socket object."""

    __slots__ = ("_sock", "_name")

    def __init__(
        self,
        runner: "BaseRunner",
        sock: socket.socket,
        *,
        shutdown_timeout: float = 60.0,
        ssl_context: Optional[SSLContext] = None,
        backlog: int = 128,
    ) -> None:
        super().__init__(
            runner,
            backlog=backlog,
            ssl_context=ssl_context,
            shutdown_timeout=shutdown_timeout,
        )
        self._sock = sock
        scheme = "https" if self._ssl_context else "http"
        # AF_UNIX is absent on some platforms, so probe for it first.
        is_unix = hasattr(socket, "AF_UNIX") and sock.family == socket.AF_UNIX
        if is_unix:
            self._name = f"{scheme}://unix:{sock.getsockname()}:"
        else:
            # Only the first two items (host, port) of the address are used.
            host, port = sock.getsockname()[:2]
            self._name = str(URL.build(scheme=scheme, host=host, port=port))

    @property
    def name(self) -> str:
        """Name computed from the socket's address at construction time."""
        return self._name

    async def start(self) -> None:
        """Register with the runner and serve on the given socket."""
        await super().start()
        web_server = self._runner.server
        assert web_server is not None
        loop = asyncio.get_event_loop()
        self._server = await loop.create_server(
            web_server,
            sock=self._sock,
            ssl=self._ssl_context,
            backlog=self._backlog,
        )

263 

264 

class BaseRunner(ABC):
    """Base class for runners that own a :class:`Server` and its sites.

    Lifecycle: ``await setup()`` builds the server via :meth:`_make_server`;
    sites are then created against this runner and started; ``await
    cleanup()`` stops every registered site and releases the server via
    :meth:`_cleanup_server`. Extra keyword arguments are kept in
    ``self._kwargs`` for subclasses.
    """

    __slots__ = ("starting_tasks", "_handle_signals", "_kwargs", "_server", "_sites")

    def __init__(self, *, handle_signals: bool = False, **kwargs: Any) -> None:
        self._handle_signals = handle_signals
        self._kwargs = kwargs
        self._server: Optional[Server] = None
        self._sites: List[BaseSite] = []
        # Tasks already running when setup() finished; sites skip them while
        # draining in stop(). Bound here so the slot is never left unset
        # (reading an unset slot raises AttributeError).
        self.starting_tasks: Set["asyncio.Task[object]"] = set()

    @property
    def server(self) -> Optional[Server]:
        """Server built by :meth:`setup`; ``None`` before setup / after cleanup."""
        return self._server

    @property
    def addresses(self) -> List[Any]:
        """``getsockname()`` of every socket of every started site."""
        ret: List[Any] = []
        for site in self._sites:
            server = site._server
            if server is not None:
                sockets = server.sockets  # type: ignore[attr-defined]
                if sockets is not None:
                    for sock in sockets:
                        ret.append(sock.getsockname())
        return ret

    @property
    def sites(self) -> Set[BaseSite]:
        """Snapshot set of the currently registered sites."""
        return set(self._sites)

    async def setup(self) -> None:
        """Install optional signal handlers, build the server, record tasks."""
        loop = asyncio.get_event_loop()

        if self._handle_signals:
            try:
                loop.add_signal_handler(signal.SIGINT, _raise_graceful_exit)
                loop.add_signal_handler(signal.SIGTERM, _raise_graceful_exit)
            except NotImplementedError:  # pragma: no cover
                # add_signal_handler is not implemented on Windows
                pass

        self._server = await self._make_server()
        # On shutdown we want to avoid waiting on tasks which run forever.
        # It's very likely that all tasks which run forever will have been created by
        # the time we have completed the application startup (in self._make_server()),
        # so we just record all running tasks here and exclude them later.
        self.starting_tasks = asyncio.all_tasks()

    @abstractmethod
    async def shutdown(self) -> None:
        """Hook awaited by each site's ``stop()`` before the server shuts down."""

    async def cleanup(self) -> None:
        """Stop all sites, release the server and remove signal handlers.

        Signal handlers installed by :meth:`setup` are removed even if
        stopping a site or :meth:`_cleanup_server` raises; the exception
        still propagates to the caller.
        """
        loop = asyncio.get_event_loop()
        try:
            # The loop over sites is intentional, an exception on gather()
            # leaves self._sites in unpredictable state.
            # The loop guarantees that a site is either deleted on success or
            # still present on failure
            for site in list(self._sites):
                await site.stop()
            await self._cleanup_server()
            self._server = None
        finally:
            if self._handle_signals:
                try:
                    loop.remove_signal_handler(signal.SIGINT)
                    loop.remove_signal_handler(signal.SIGTERM)
                except NotImplementedError:  # pragma: no cover
                    # remove_signal_handler is not implemented on Windows
                    pass

    @abstractmethod
    async def _make_server(self) -> Server:
        """Build the :class:`Server` that sites will expose."""

    @abstractmethod
    async def _cleanup_server(self) -> None:
        """Release resources acquired by :meth:`_make_server`."""

    def _reg_site(self, site: BaseSite) -> None:
        """Add *site*; raise RuntimeError if it is already registered."""
        if site in self._sites:
            raise RuntimeError(f"Site {site} is already registered in runner {self}")
        self._sites.append(site)

    def _check_site(self, site: BaseSite) -> None:
        """Raise RuntimeError unless *site* is registered."""
        if site not in self._sites:
            raise RuntimeError(f"Site {site} is not registered in runner {self}")

    def _unreg_site(self, site: BaseSite) -> None:
        """Remove *site*; raise RuntimeError if it is not registered."""
        self._check_site(site)
        self._sites.remove(site)

356 

357 

class ServerRunner(BaseRunner):
    """Low-level web server runner.

    Serves an already-constructed :class:`Server`. There is no application
    lifecycle to drive, so the shutdown and cleanup hooks do nothing.
    """

    __slots__ = ("_web_server",)

    def __init__(
        self, web_server: Server, *, handle_signals: bool = False, **kwargs: Any
    ) -> None:
        super().__init__(handle_signals=handle_signals, **kwargs)
        self._web_server = web_server

    async def shutdown(self) -> None:
        """No-op: a bare server has no shutdown hooks."""

    async def _make_server(self) -> Server:
        """Hand back the server supplied to the constructor."""
        return self._web_server

    async def _cleanup_server(self) -> None:
        """No-op: the caller owns the server it supplied."""

377 

378 

class AppRunner(BaseRunner):
    """Web Application runner.

    Drives an :class:`Application` through startup, shutdown and cleanup,
    and serves it through a :class:`Server` built in :meth:`_make_server`.
    """

    __slots__ = ("_app",)

    def __init__(
        self,
        app: Application,
        *,
        handle_signals: bool = False,
        access_log_class: Type[AbstractAccessLogger] = AccessLogger,
        **kwargs: Any,
    ) -> None:
        """Validate *app* and the access-logger class, then store settings.

        Raises:
            TypeError: if *app* is not an :class:`Application`, or if the
                effective ``access_log_class`` is not an
                :class:`AbstractAccessLogger` subclass.
        """
        if not isinstance(app, Application):
            raise TypeError(
                "The first argument should be web.Application "
                "instance, got {!r}".format(app)
            )
        kwargs["access_log_class"] = access_log_class

        # Application-level handler args take precedence over runner kwargs
        # (including access_log_class), applied in iteration order.
        handler_args = app._handler_args
        if handler_args:
            kwargs.update(handler_args.items())

        log_cls = kwargs["access_log_class"]
        if not issubclass(log_cls, AbstractAccessLogger):
            raise TypeError(
                "access_log_class must be subclass of "
                "aiohttp.abc.AbstractAccessLogger, got {}".format(log_cls)
            )

        super().__init__(handle_signals=handle_signals, **kwargs)
        self._app = app

    @property
    def app(self) -> Application:
        """The application being served."""
        return self._app

    async def shutdown(self) -> None:
        """Delegate to ``app.shutdown()``."""
        await self._app.shutdown()

    async def _make_server(self) -> Server:
        """Freeze startup hooks, run ``app.startup()``, freeze the app, wrap it."""
        app = self._app
        app.on_startup.freeze()
        await app.startup()
        app.freeze()
        return Server(
            app._handle,  # type: ignore[arg-type]
            request_factory=self._make_request,
            **self._kwargs,
        )

    def _make_request(
        self,
        message: RawRequestMessage,
        payload: StreamReader,
        protocol: RequestHandler,
        writer: AbstractStreamWriter,
        task: "asyncio.Task[None]",
        _cls: Type[Request] = Request,
    ) -> Request:
        """Request factory: build *_cls* bound to the running loop."""
        running_loop = asyncio.get_running_loop()
        max_size = self.app._client_max_size
        return _cls(
            message,
            payload,
            protocol,
            writer,
            task,
            running_loop,
            client_max_size=max_size,
        )

    async def _cleanup_server(self) -> None:
        """Delegate to ``app.cleanup()``."""
        await self._app.cleanup()