Coverage for /pythoncovmergedfiles/medio/medio/src/aiohttp/aiohttp/web_runner.py: 39%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

233 statements  

1import asyncio 

2import signal 

3import socket 

4from abc import ABC, abstractmethod 

5from typing import Any, Generic, TypeVar 

6 

7from yarl import URL 

8 

9from .abc import AbstractAccessLogger, AbstractStreamWriter 

10from .http_parser import RawRequestMessage 

11from .streams import StreamReader 

12from .typedefs import PathLike 

13from .web_app import Application 

14from .web_log import AccessLogger 

15from .web_protocol import RequestHandler 

16from .web_request import BaseRequest, Request 

17from .web_server import Server 

18 

19try: 

20 from ssl import SSLContext 

21except ImportError: # pragma: no cover 

22 SSLContext = object # type: ignore[misc,assignment] 

23 

# Public names exported by ``from aiohttp.web_runner import *``.
__all__ = (
    "BaseSite", "TCPSite", "UnixSite", "NamedPipeSite", "SockSite",
    "BaseRunner", "AppRunner", "ServerRunner", "GracefulExit",
)

# Request type produced by a runner's Server; must be BaseRequest or a subclass.
_Request = TypeVar("_Request", bound=BaseRequest)

37 

38 

class GracefulExit(SystemExit):
    """Exception used to unwind the event loop when a stop signal arrives.

    It derives from ``SystemExit`` so it propagates like ``sys.exit()``; the
    class-level ``code`` overrides the exit status so an uncaught instance
    terminates the interpreter with status 1.
    """

    code = 1

41 

42 

def _raise_graceful_exit() -> None:
    """Signal-handler callback: abort the running loop with ``GracefulExit``."""
    raise GracefulExit

45 

46 

47class BaseSite(ABC): 

48 __slots__ = ("_runner", "_ssl_context", "_backlog", "_server") 

49 

50 def __init__( 

51 self, 

52 runner: "BaseRunner[Any]", 

53 *, 

54 ssl_context: SSLContext | None = None, 

55 backlog: int = 128, 

56 ) -> None: 

57 if runner.server is None: 

58 raise RuntimeError("Call runner.setup() before making a site") 

59 self._runner = runner 

60 self._ssl_context = ssl_context 

61 self._backlog = backlog 

62 self._server: asyncio.Server | None = None 

63 

64 @property 

65 @abstractmethod 

66 def name(self) -> str: 

67 """Return the name of the site (e.g. a URL).""" 

68 

69 @abstractmethod 

70 async def start(self) -> None: 

71 self._runner._reg_site(self) 

72 

73 async def stop(self) -> None: 

74 self._runner._check_site(self) 

75 if self._server is not None: # Maybe not started yet 

76 self._server.close() 

77 

78 self._runner._unreg_site(self) 

79 

80 

class TCPSite(BaseSite):
    """Site that listens on a TCP host/port via ``loop.create_server``.

    When ``port`` is omitted it defaults to 8443 if an SSL context is given,
    otherwise 8080. ``reuse_address``/``reuse_port`` are passed straight
    through to ``create_server``.
    """

    __slots__ = ("_host", "_port", "_bound_port", "_reuse_address", "_reuse_port")

    def __init__(
        self,
        runner: "BaseRunner[Any]",
        host: str | None = None,
        port: int | None = None,
        *,
        ssl_context: SSLContext | None = None,
        backlog: int = 128,
        reuse_address: bool | None = None,
        reuse_port: bool | None = None,
    ) -> None:
        super().__init__(
            runner,
            ssl_context=ssl_context,
            backlog=backlog,
        )
        self._host = host
        if port is None:
            port = 8443 if self._ssl_context else 8080
        self._port = port
        # Actual port after start(); differs from _port when port=0 was requested.
        self._bound_port: int | None = None
        self._reuse_address = reuse_address
        self._reuse_port = reuse_port

    @property
    def port(self) -> int:
        """The port the server is listening on.

        If the server hasn't been started yet, this returns the requested port
        (which might be 0 for a dynamic port).
        After the server starts, it returns the actual bound port. This is
        especially useful when port=0 was requested, as it allows retrieving the
        dynamically assigned port after the site has started.
        """
        if self._bound_port is not None:
            return self._bound_port
        return self._port

    @property
    def name(self) -> str:
        """URL of the site; an empty/None host is displayed as ``0.0.0.0``."""
        scheme = "https" if self._ssl_context else "http"
        host = "0.0.0.0" if not self._host else self._host
        return str(URL.build(scheme=scheme, host=host, port=self.port))

    async def start(self) -> None:
        """Register the site and start listening on the configured host/port."""
        await super().start()
        # Inside a coroutine the running loop is the one we want; asyncio
        # documents get_running_loop() as the preferred call here.
        loop = asyncio.get_running_loop()
        server = self._runner.server
        assert server is not None
        self._server = await loop.create_server(
            server,
            self._host,
            self._port,
            ssl=self._ssl_context,
            backlog=self._backlog,
            reuse_address=self._reuse_address,
            reuse_port=self._reuse_port,
        )
        # Only the first listening socket's port is recorded; with several
        # sockets (e.g. host=None on a dual-stack machine) they may differ
        # when port=0 was requested.
        if self._server.sockets:
            self._bound_port = self._server.sockets[0].getsockname()[1]
        else:
            self._bound_port = self._port

146 

147 

class UnixSite(BaseSite):
    """Site that listens on a Unix domain socket via ``loop.create_unix_server``."""

    __slots__ = ("_path",)

    def __init__(
        self,
        runner: "BaseRunner[Any]",
        path: PathLike,
        *,
        ssl_context: SSLContext | None = None,
        backlog: int = 128,
    ) -> None:
        super().__init__(
            runner,
            ssl_context=ssl_context,
            backlog=backlog,
        )
        # Filesystem location of the socket (PathLike presumably accepts str or
        # os.PathLike - see .typedefs).
        self._path = path

    @property
    def name(self) -> str:
        """Pseudo-URL of the form ``http[s]://unix:<path>:``."""
        scheme = "https" if self._ssl_context else "http"
        return f"{scheme}://unix:{self._path}:"

    async def start(self) -> None:
        """Register the site and start listening on the Unix socket path."""
        await super().start()
        # Inside a coroutine, get_running_loop() is the documented choice.
        loop = asyncio.get_running_loop()
        server = self._runner.server
        assert server is not None
        self._server = await loop.create_unix_server(
            server,
            self._path,
            ssl=self._ssl_context,
            backlog=self._backlog,
        )

182 

183 

class NamedPipeSite(BaseSite):
    """Site that serves on a Windows named pipe.

    Requires a proactor event loop. No SSL context or backlog is passed for
    this site type, so the ``BaseSite`` defaults apply.
    """

    __slots__ = ("_path",)

    def __init__(self, runner: "BaseRunner[Any]", path: str) -> None:
        loop = asyncio.get_event_loop()
        # asyncio only defines ProactorEventLoop on Windows. Look it up with
        # getattr so other platforms get the intended RuntimeError below
        # instead of an AttributeError from the attribute access itself.
        proactor_loop_cls = getattr(asyncio, "ProactorEventLoop", None)
        if proactor_loop_cls is None or not isinstance(loop, proactor_loop_cls):
            raise RuntimeError(
                "Named Pipes only available in proactor loop under windows"
            )
        super().__init__(runner)
        self._path = path

    @property
    def name(self) -> str:
        """The pipe path itself."""
        return self._path

    async def start(self) -> None:
        """Register the site and start serving on the named pipe."""
        await super().start()
        # Inside a coroutine, get_running_loop() is the documented choice.
        loop = asyncio.get_running_loop()
        server = self._runner.server
        assert server is not None
        _server = await loop.start_serving_pipe(  # type: ignore[attr-defined]
            server, self._path
        )
        # start_serving_pipe returns a list of pipe servers; keep the first.
        self._server = _server[0]

211 

212 

class SockSite(BaseSite):
    """Site that serves on an already-created, caller-supplied socket.

    The display name is computed once at construction time from
    ``sock.getsockname()``.
    """

    __slots__ = ("_sock", "_name")

    def __init__(
        self,
        runner: "BaseRunner[Any]",
        sock: socket.socket,
        *,
        ssl_context: SSLContext | None = None,
        backlog: int = 128,
    ) -> None:
        super().__init__(
            runner,
            ssl_context=ssl_context,
            backlog=backlog,
        )
        self._sock = sock
        scheme = "https" if self._ssl_context else "http"
        # AF_UNIX does not exist on every platform, hence the hasattr guard.
        if hasattr(socket, "AF_UNIX") and sock.family == socket.AF_UNIX:
            name = f"{scheme}://unix:{sock.getsockname()}:"
        else:
            # IPv6 getsockname() returns a 4-tuple; only host and port matter here.
            host, port = sock.getsockname()[:2]
            name = str(URL.build(scheme=scheme, host=host, port=port))
        self._name = name

    @property
    def name(self) -> str:
        """Name computed in ``__init__`` from the socket's local address."""
        return self._name

    async def start(self) -> None:
        """Register the site and start serving on the supplied socket."""
        await super().start()
        # Inside a coroutine, get_running_loop() is the documented choice.
        loop = asyncio.get_running_loop()
        server = self._runner.server
        assert server is not None
        self._server = await loop.create_server(
            server, sock=self._sock, ssl=self._ssl_context, backlog=self._backlog
        )

250 

251 

class BaseRunner(ABC, Generic[_Request]):
    """Owns a ``Server`` and the sites that serve it.

    ``setup()`` builds the server (optionally installing SIGINT/SIGTERM
    handlers that raise ``GracefulExit``); sites register themselves when they
    start; ``cleanup()`` stops every site, shuts the server down and removes
    the signal handlers.
    """

    __slots__ = ("_handle_signals", "_kwargs", "_server", "_sites", "_shutdown_timeout")

    def __init__(
        self,
        *,
        handle_signals: bool = False,
        shutdown_timeout: float = 60.0,
        **kwargs: Any,
    ) -> None:
        self._handle_signals = handle_signals
        # Extra keyword arguments kept for subclasses (AppRunner forwards them
        # to Server).
        self._kwargs = kwargs
        self._server: Server[_Request] | None = None
        self._sites: list[BaseSite] = []
        # Passed to Server.shutdown() during cleanup().
        self._shutdown_timeout = shutdown_timeout

    @property
    def server(self) -> Server[_Request] | None:
        """The server built by ``setup()``, or None before setup/after cleanup."""
        return self._server

    @property
    def addresses(self) -> list[Any]:
        """Local addresses (``getsockname()``) of every started site's sockets."""
        ret: list[Any] = []
        for site in self._sites:
            server = site._server
            if server is None:  # registered, but not (successfully) started
                continue
            # NamedPipeSite stores a pipe server that has no ``sockets``
            # attribute; look it up defensively instead of raising
            # AttributeError.
            sockets = getattr(server, "sockets", None)
            if sockets is not None:
                ret.extend(sock.getsockname() for sock in sockets)
        return ret

    @property
    def sites(self) -> set[BaseSite]:
        """A snapshot set of the currently registered sites."""
        return set(self._sites)

    async def setup(self) -> None:
        """Install signal handlers (if requested) and build the server.

        Must be awaited before any site is created for this runner.
        """
        # Inside a coroutine, get_running_loop() is the documented choice and
        # matches cleanup() below.
        loop = asyncio.get_running_loop()

        if self._handle_signals:
            try:
                loop.add_signal_handler(signal.SIGINT, _raise_graceful_exit)
                loop.add_signal_handler(signal.SIGTERM, _raise_graceful_exit)
            except NotImplementedError:
                # add_signal_handler is not implemented on Windows
                pass

        self._server = await self._make_server()

    @abstractmethod
    async def shutdown(self) -> None:
        """Call any shutdown hooks to help server close gracefully."""

    async def cleanup(self) -> None:
        """Stop all sites, shut the server down and remove signal handlers."""
        # The loop over sites is intentional, an exception on gather()
        # leaves self._sites in unpredictable state.
        # The loop guarantees that a site is either deleted on success or
        # still present on failure
        for site in list(self._sites):
            await site.stop()

        if self._server:  # If setup succeeded
            # Yield to event loop to ensure incoming requests prior to stopping the sites
            # have all started to be handled before we proceed to close idle connections.
            await asyncio.sleep(0)
            self._server.pre_shutdown()
            await self.shutdown()
            await self._server.shutdown(self._shutdown_timeout)
        await self._cleanup_server()

        self._server = None
        if self._handle_signals:
            loop = asyncio.get_running_loop()
            try:
                loop.remove_signal_handler(signal.SIGINT)
                loop.remove_signal_handler(signal.SIGTERM)
            except NotImplementedError:
                # remove_signal_handler is not implemented on Windows
                pass

    @abstractmethod
    async def _make_server(self) -> Server[_Request]:
        """Return a new server for the runner to serve requests."""

    @abstractmethod
    async def _cleanup_server(self) -> None:
        """Run any cleanup steps after the server is shutdown."""

    def _reg_site(self, site: BaseSite) -> None:
        """Record *site* as registered; raise RuntimeError if it already is."""
        if site in self._sites:
            raise RuntimeError(f"Site {site} is already registered in runner {self}")
        self._sites.append(site)

    def _check_site(self, site: BaseSite) -> None:
        """Raise RuntimeError unless *site* is registered with this runner."""
        if site not in self._sites:
            raise RuntimeError(f"Site {site} is not registered in runner {self}")

    def _unreg_site(self, site: BaseSite) -> None:
        """Remove *site*; raise RuntimeError if it is not registered."""
        self._check_site(site)
        self._sites.remove(site)

353 

354 

class ServerRunner(BaseRunner[BaseRequest]):
    """Runner around a caller-built low-level ``Server``.

    It has no application hooks: ``shutdown`` and server cleanup do nothing,
    and ``setup()`` simply adopts the server passed to the constructor.
    """

    __slots__ = ("_web_server",)

    def __init__(
        self,
        web_server: Server[BaseRequest],
        *,
        handle_signals: bool = False,
        **kwargs: Any,
    ) -> None:
        super().__init__(handle_signals=handle_signals, **kwargs)
        self._web_server = web_server

    async def shutdown(self) -> None:
        """No application-level shutdown hooks to run."""
        return None

    async def _make_server(self) -> Server[BaseRequest]:
        """Hand back the server supplied at construction time."""
        web_server = self._web_server
        return web_server

    async def _cleanup_server(self) -> None:
        """Nothing to release after the server shuts down."""
        return None

378 

379 

class AppRunner(BaseRunner[Request]):
    """Runner for a high-level ``web.Application``.

    Builds a ``Server`` that dispatches to the application and ties the
    application's startup/shutdown/cleanup to the runner's lifecycle.
    """

    __slots__ = ("_app",)

    def __init__(
        self,
        app: Application,
        *,
        handle_signals: bool = False,
        access_log_class: type[AbstractAccessLogger] = AccessLogger,
        **kwargs: Any,
    ) -> None:
        if not isinstance(app, Application):
            raise TypeError(
                f"The first argument should be web.Application instance, got {app!r}"
            )
        kwargs["access_log_class"] = access_log_class

        # Handler arguments stored on the application override the keyword
        # arguments passed to the runner (including access_log_class).
        if app._handler_args:
            kwargs.update(app._handler_args)

        log_cls = kwargs["access_log_class"]
        # issubclass() raises its own generic TypeError for non-classes (e.g.
        # an AccessLogger *instance*); check isinstance(..., type) first so the
        # user always sees the descriptive message below.
        if not (isinstance(log_cls, type) and issubclass(log_cls, AbstractAccessLogger)):
            raise TypeError(
                "access_log_class must be subclass of "
                "aiohttp.abc.AbstractAccessLogger, got {}".format(log_cls)
            )

        super().__init__(handle_signals=handle_signals, **kwargs)
        self._app = app

    @property
    def app(self) -> Application:
        """The application served by this runner."""
        return self._app

    async def shutdown(self) -> None:
        """Delegate to ``Application.shutdown()``."""
        await self._app.shutdown()

    async def _make_server(self) -> Server[Request]:
        """Freeze ``on_startup``, run ``app.startup()``, freeze the app, build the Server."""
        self._app.on_startup.freeze()
        await self._app.startup()
        self._app.freeze()

        return Server(
            self._app._handle,
            request_factory=self._make_request,
            **self._kwargs,
        )

    def _make_request(
        self,
        message: RawRequestMessage,
        payload: StreamReader,
        protocol: RequestHandler[Request],
        writer: AbstractStreamWriter,
        task: "asyncio.Task[None]",
        _cls: type[Request] = Request,
    ) -> Request:
        """Request factory used by the Server; applies the app's client_max_size."""
        loop = asyncio.get_running_loop()
        return _cls(
            message,
            payload,
            protocol,
            writer,
            task,
            loop,
            client_max_size=self.app._client_max_size,
        )

    async def _cleanup_server(self) -> None:
        """Delegate to ``Application.cleanup()``."""
        await self._app.cleanup()