Coverage for /pythoncovmergedfiles/medio/medio/src/aiohttp/aiohttp/web_runner.py: 40%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

226 statements  

1import asyncio 

2import signal 

3import socket 

4from abc import ABC, abstractmethod 

5from typing import TYPE_CHECKING, Any, Generic, List, Optional, Set, Type, TypeVar 

6 

7from yarl import URL 

8 

9from .abc import AbstractAccessLogger, AbstractStreamWriter 

10from .http_parser import RawRequestMessage 

11from .streams import StreamReader 

12from .typedefs import PathLike 

13from .web_app import Application 

14from .web_log import AccessLogger 

15from .web_protocol import RequestHandler 

16from .web_request import BaseRequest, Request 

17from .web_server import Server 

18 

19if TYPE_CHECKING: 

20 from ssl import SSLContext 

21else: 

22 try: 

23 from ssl import SSLContext 

24 except ImportError: # pragma: no cover 

25 SSLContext = object # type: ignore[misc,assignment] 

26 

# Public API of this module, grouped by role.
__all__ = (
    # Sites: listening endpoints attached to a runner.
    "BaseSite",
    "TCPSite",
    "UnixSite",
    "NamedPipeSite",
    "SockSite",
    # Runners: own the low-level server and its sites.
    "BaseRunner",
    "AppRunner",
    "ServerRunner",
    # Raised by the runner's optional SIGINT/SIGTERM handlers.
    "GracefulExit",
)

# Request type produced by a runner's server; constrained to BaseRequest subclasses.
_Request = TypeVar("_Request", bound=BaseRequest)

40 

41 

class GracefulExit(SystemExit):
    """SystemExit subclass raised from the runner's signal handlers.

    ``code`` is overridden at class level, so every instance reports exit
    status 1, whatever arguments it was created with.
    """

    code = 1

44 

45 

def _raise_graceful_exit() -> None:
    """Signal-handler callback: unconditionally raise GracefulExit.

    Registered for SIGINT and SIGTERM by BaseRunner.setup() when
    handle_signals is enabled.
    """
    exit_exc = GracefulExit()
    raise exit_exc

48 

49 

class BaseSite(ABC):
    """Abstract listening endpoint attached to a runner's low-level server.

    Concrete sites implement ``name`` and a ``start`` coroutine that first
    awaits ``super().start()`` (registering the site with the runner) and then
    stores the listening object in ``self._server`` so ``stop`` can close it.
    """

    __slots__ = ("_runner", "_ssl_context", "_backlog", "_server")

    def __init__(
        self,
        runner: "BaseRunner[Any]",
        *,
        ssl_context: Optional[SSLContext] = None,
        backlog: int = 128,
    ) -> None:
        # A site can only be built after the runner has created its server.
        if runner.server is None:
            raise RuntimeError("Call runner.setup() before making a site")
        self._runner = runner
        self._ssl_context = ssl_context
        self._backlog = backlog
        # Set by the subclass's start(); None means "not listening yet".
        self._server: Optional[asyncio.Server] = None

    @property
    @abstractmethod
    def name(self) -> str:
        """Return the name of the site (e.g. a URL)."""

    @abstractmethod
    async def start(self) -> None:
        """Register this site with its runner; subclasses then begin listening."""
        self._runner._reg_site(self)

    async def stop(self) -> None:
        """Close the listener if one was started, then unregister the site.

        The runner's ``_check_site`` is consulted first; with BaseRunner it
        raises RuntimeError for a site that is not registered.
        """
        owner = self._runner
        owner._check_site(self)
        listener = self._server
        if listener is not None:  # start() may never have been awaited
            listener.close()
        owner._unreg_site(self)

82 

83 

class TCPSite(BaseSite):
    """Site listening on a TCP host/port via ``loop.create_server``."""

    __slots__ = ("_host", "_port", "_reuse_address", "_reuse_port")

    def __init__(
        self,
        runner: "BaseRunner[Any]",
        host: Optional[str] = None,
        port: Optional[int] = None,
        *,
        ssl_context: Optional[SSLContext] = None,
        backlog: int = 128,
        reuse_address: Optional[bool] = None,
        reuse_port: Optional[bool] = None,
    ) -> None:
        """Create a TCP site for *runner*.

        :param host: interface to bind; ``None`` is passed to
            ``loop.create_server`` unchanged.
        :param port: port to bind; defaults to 8443 when *ssl_context* is
            given, otherwise 8080.
        :param ssl_context: TLS context; when set the site serves HTTPS.
        :param backlog: listen backlog passed to ``loop.create_server``.
        :param reuse_address: forwarded to ``loop.create_server``.
        :param reuse_port: forwarded to ``loop.create_server``.
        :raises RuntimeError: if ``runner.setup()`` has not been awaited yet.
        """
        super().__init__(
            runner,
            ssl_context=ssl_context,
            backlog=backlog,
        )
        self._host = host
        if port is None:
            port = 8443 if self._ssl_context else 8080
        self._port = port
        self._reuse_address = reuse_address
        self._reuse_port = reuse_port

    @property
    def name(self) -> str:
        """URL of the site; an unset host is displayed as ``0.0.0.0``."""
        scheme = "https" if self._ssl_context else "http"
        # Display-only fallback: create_server still receives the original host.
        host = self._host or "0.0.0.0"
        return str(URL.build(scheme=scheme, host=host, port=self._port))

    async def start(self) -> None:
        """Register with the runner and start listening on host/port."""
        await super().start()
        # Inside a coroutine, get_running_loop() is the documented way to get
        # the current loop (and what BaseRunner.cleanup() already uses).
        loop = asyncio.get_running_loop()
        server = self._runner.server
        assert server is not None
        self._server = await loop.create_server(
            server,
            self._host,
            self._port,
            ssl=self._ssl_context,
            backlog=self._backlog,
            reuse_address=self._reuse_address,
            reuse_port=self._reuse_port,
        )

130 

131 

class UnixSite(BaseSite):
    """Site listening on a Unix domain socket path via ``loop.create_unix_server``."""

    __slots__ = ("_path",)

    def __init__(
        self,
        runner: "BaseRunner[Any]",
        path: PathLike,
        *,
        ssl_context: Optional[SSLContext] = None,
        backlog: int = 128,
    ) -> None:
        """Create a Unix-socket site bound to *path*.

        :raises RuntimeError: if ``runner.setup()`` has not been awaited yet.
        """
        super().__init__(
            runner,
            ssl_context=ssl_context,
            backlog=backlog,
        )
        self._path = path

    @property
    def name(self) -> str:
        """Pseudo-URL of the form ``http(s)://unix:<path>:``."""
        scheme = "https" if self._ssl_context else "http"
        return f"{scheme}://unix:{self._path}:"

    async def start(self) -> None:
        """Register with the runner and start listening on the socket path."""
        await super().start()
        # Inside a coroutine, prefer get_running_loop() over get_event_loop().
        loop = asyncio.get_running_loop()
        server = self._runner.server
        assert server is not None
        self._server = await loop.create_unix_server(
            server,
            self._path,
            ssl=self._ssl_context,
            backlog=self._backlog,
        )

166 

167 

class NamedPipeSite(BaseSite):
    """Site serving on a Windows named pipe; requires a proactor event loop."""

    __slots__ = ("_path",)

    def __init__(self, runner: "BaseRunner[Any]", path: str) -> None:
        """Create a named-pipe site for *path*.

        :raises RuntimeError: if the current event loop is not a
            ProactorEventLoop (including on platforms where that class does
            not exist), or if ``runner.setup()`` has not been awaited yet.
        """
        # asyncio.ProactorEventLoop is only defined on Windows. Referencing it
        # directly raised AttributeError elsewhere instead of the intended
        # RuntimeError, so look it up defensively.
        proactor_cls = getattr(asyncio, "ProactorEventLoop", None)
        if proactor_cls is None or not isinstance(
            asyncio.get_event_loop(), proactor_cls
        ):
            raise RuntimeError(
                "Named Pipes only available in proactor loop under windows"
            )
        super().__init__(runner)
        self._path = path

    @property
    def name(self) -> str:
        """The pipe path itself."""
        return self._path

    async def start(self) -> None:
        """Register with the runner and start serving on the named pipe."""
        await super().start()
        loop = asyncio.get_running_loop()
        server = self._runner.server
        assert server is not None
        # start_serving_pipe is ProactorEventLoop-specific (hence the ignore);
        # its result is a sequence whose first item is kept for stop() to close.
        pipe_servers = await loop.start_serving_pipe(  # type: ignore[attr-defined]
            server, self._path
        )
        self._server = pipe_servers[0]

195 

196 

class SockSite(BaseSite):
    """Site serving on an already-created, caller-supplied socket."""

    __slots__ = ("_sock", "_name")

    def __init__(
        self,
        runner: "BaseRunner[Any]",
        sock: socket.socket,
        *,
        ssl_context: Optional[SSLContext] = None,
        backlog: int = 128,
    ) -> None:
        """Create a site for *sock*; its display name is computed once here.

        :raises RuntimeError: if ``runner.setup()`` has not been awaited yet.
        """
        super().__init__(
            runner,
            ssl_context=ssl_context,
            backlog=backlog,
        )
        self._sock = sock
        scheme = "https" if self._ssl_context else "http"
        # socket.AF_UNIX is not defined on every platform, hence the hasattr guard.
        if hasattr(socket, "AF_UNIX") and sock.family == socket.AF_UNIX:
            name = f"{scheme}://unix:{sock.getsockname()}:"
        else:
            # IPv6 getsockname() returns 4 items; only (host, port) are used.
            host, port = sock.getsockname()[:2]
            name = str(URL.build(scheme=scheme, host=host, port=port))
        self._name = name

    @property
    def name(self) -> str:
        """Name computed from the socket address at construction time."""
        return self._name

    async def start(self) -> None:
        """Register with the runner and start serving on the socket."""
        await super().start()
        # Inside a coroutine, prefer get_running_loop() over get_event_loop().
        loop = asyncio.get_running_loop()
        server = self._runner.server
        assert server is not None
        self._server = await loop.create_server(
            server, sock=self._sock, ssl=self._ssl_context, backlog=self._backlog
        )

234 

235 

class BaseRunner(ABC, Generic[_Request]):
    """Owns a low-level Server and the sites that expose it.

    Lifecycle: ``await setup()`` creates the server (optionally installing
    SIGINT/SIGTERM handlers), sites are then started against it, and
    ``await cleanup()`` stops every site and shuts the server down.
    """

    __slots__ = ("_handle_signals", "_kwargs", "_server", "_sites", "_shutdown_timeout")

    def __init__(
        self,
        *,
        handle_signals: bool = False,
        shutdown_timeout: float = 60.0,
        **kwargs: Any,
    ) -> None:
        """
        :param handle_signals: install GracefulExit handlers for SIGINT/SIGTERM
            in ``setup()`` and remove them in ``cleanup()``.
        :param shutdown_timeout: passed to ``Server.shutdown()`` during cleanup.
        :param kwargs: extra options kept for subclasses (AppRunner forwards
            them to ``Server``).
        """
        self._handle_signals = handle_signals
        self._kwargs = kwargs
        self._server: Optional[Server[_Request]] = None
        self._sites: List[BaseSite] = []
        self._shutdown_timeout = shutdown_timeout

    @property
    def server(self) -> Optional[Server[_Request]]:
        """The server created by ``setup()``, or None before setup/after cleanup."""
        return self._server

    @property
    def addresses(self) -> List[Any]:  # type: ignore[misc]
        """``getsockname()`` of every socket of every started site."""
        ret: List[Any] = []
        for site in self._sites:
            server = site._server
            if server is None:  # site registered but not listening yet
                continue
            sockets = server.sockets
            if sockets is not None:
                ret.extend(sock.getsockname() for sock in sockets)
        return ret

    @property
    def sites(self) -> Set[BaseSite]:
        """Snapshot of the registered sites."""
        return set(self._sites)

    async def setup(self) -> None:
        """Create the server, installing signal handlers first if requested."""
        # Inside a coroutine, get_running_loop() is the documented way to get
        # the loop, and matches what cleanup() uses.
        loop = asyncio.get_running_loop()

        if self._handle_signals:
            try:
                loop.add_signal_handler(signal.SIGINT, _raise_graceful_exit)
                loop.add_signal_handler(signal.SIGTERM, _raise_graceful_exit)
            except NotImplementedError:
                # add_signal_handler is not implemented on Windows
                pass

        self._server = await self._make_server()

    @abstractmethod
    async def shutdown(self) -> None:
        """Call any shutdown hooks to help server close gracefully."""

    async def cleanup(self) -> None:
        """Stop all sites, shut the server down, and undo ``setup()``."""
        # The loop over sites is intentional, an exception on gather()
        # leaves self._sites in unpredictable state.
        # The loop guarantees that a site is either deleted on success or
        # still present on failure
        for site in list(self._sites):
            await site.stop()

        if self._server:  # If setup succeeded
            # Yield to event loop to ensure incoming requests prior to stopping the sites
            # have all started to be handled before we proceed to close idle connections.
            await asyncio.sleep(0)
            self._server.pre_shutdown()
            await self.shutdown()
            await self._server.shutdown(self._shutdown_timeout)
        # Runs unconditionally, even if setup() did not produce a server.
        await self._cleanup_server()

        self._server = None
        if self._handle_signals:
            loop = asyncio.get_running_loop()
            try:
                loop.remove_signal_handler(signal.SIGINT)
                loop.remove_signal_handler(signal.SIGTERM)
            except NotImplementedError:
                # remove_signal_handler is not implemented on Windows
                pass

    @abstractmethod
    async def _make_server(self) -> Server[_Request]:
        """Return a new server for the runner to serve requests."""

    @abstractmethod
    async def _cleanup_server(self) -> None:
        """Run any cleanup steps after the server is shutdown."""

    def _reg_site(self, site: BaseSite) -> None:
        """Register *site*; RuntimeError if it is already registered."""
        if site in self._sites:
            raise RuntimeError(f"Site {site} is already registered in runner {self}")
        self._sites.append(site)

    def _check_site(self, site: BaseSite) -> None:
        """Raise RuntimeError unless *site* is registered."""
        if site not in self._sites:
            raise RuntimeError(f"Site {site} is not registered in runner {self}")

    def _unreg_site(self, site: BaseSite) -> None:
        """Unregister *site*; RuntimeError if it is not registered."""
        # Reuse the single membership check instead of duplicating it.
        self._check_site(site)
        self._sites.remove(site)

337 

338 

class ServerRunner(BaseRunner[BaseRequest]):
    """Low-level web server runner.

    Runs a pre-built ``Server`` directly, without an ``Application``, so there
    are no application startup/shutdown/cleanup hooks to invoke.
    """

    __slots__ = ("_web_server",)

    def __init__(
        self,
        web_server: Server[BaseRequest],
        *,
        handle_signals: bool = False,
        **kwargs: Any,
    ) -> None:
        """Wrap *web_server*; remaining keyword arguments go to BaseRunner."""
        super().__init__(handle_signals=handle_signals, **kwargs)
        self._web_server = web_server

    async def shutdown(self) -> None:
        """No shutdown hooks: a bare Server has none."""

    async def _make_server(self) -> Server[BaseRequest]:
        """Hand back the Server given to the constructor (nothing is created)."""
        return self._web_server

    async def _cleanup_server(self) -> None:
        """No extra cleanup beyond what BaseRunner.cleanup() performs."""

362 

363 

class AppRunner(BaseRunner[Request]):
    """Web Application runner"""

    __slots__ = ("_app",)

    def __init__(
        self,
        app: Application,
        *,
        handle_signals: bool = False,
        access_log_class: Type[AbstractAccessLogger] = AccessLogger,
        **kwargs: Any,
    ) -> None:
        """Create a runner for *app*.

        :param access_log_class: access-logger class for the Server; a value in
            ``app._handler_args`` overrides it.
        :param kwargs: forwarded to BaseRunner and, from there, to ``Server``.
        :raises TypeError: if *app* is not an Application, or the effective
            access_log_class is not an AbstractAccessLogger subclass.
        """
        if not isinstance(app, Application):
            raise TypeError(
                "The first argument should be web.Application "
                f"instance, got {app!r}"
            )
        kwargs["access_log_class"] = access_log_class

        # Handler arguments stored on the application override those given here.
        if app._handler_args:
            kwargs.update(app._handler_args)

        log_cls = kwargs["access_log_class"]
        # issubclass() raises its own, unhelpful TypeError for non-class values,
        # so verify it is a class first and report the real problem.
        if not (isinstance(log_cls, type) and issubclass(log_cls, AbstractAccessLogger)):
            raise TypeError(
                "access_log_class must be subclass of "
                f"aiohttp.abc.AbstractAccessLogger, got {log_cls}"
            )

        super().__init__(handle_signals=handle_signals, **kwargs)
        self._app = app

    @property
    def app(self) -> Application:
        """The application this runner serves."""
        return self._app

    async def shutdown(self) -> None:
        """Run the application's shutdown hooks."""
        await self._app.shutdown()

    async def _make_server(self) -> Server[Request]:
        """Run app startup, freeze the app, and wrap its handler in a Server."""
        # NOTE(review): presumably frozen first so startup hooks cannot be
        # added while they are being run — confirm.
        self._app.on_startup.freeze()
        await self._app.startup()
        self._app.freeze()

        return Server(
            self._app._handle,
            request_factory=self._make_request,
            **self._kwargs,
        )

    def _make_request(
        self,
        message: RawRequestMessage,
        payload: StreamReader,
        protocol: RequestHandler[Request],
        writer: AbstractStreamWriter,
        task: "asyncio.Task[None]",
        _cls: Type[Request] = Request,
    ) -> Request:
        """Request factory used by the Server; applies the app's client_max_size."""
        loop = asyncio.get_running_loop()
        return _cls(
            message,
            payload,
            protocol,
            writer,
            task,
            loop,
            client_max_size=self.app._client_max_size,
        )

    async def _cleanup_server(self) -> None:
        """Run the application's cleanup hooks."""
        await self._app.cleanup()