Coverage for /pythoncovmergedfiles/medio/medio/src/aiohttp/aiohttp/web_runner.py: 39%
233 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-06-07 06:52 +0000
« prev ^ index » next coverage.py v7.2.7, created at 2023-06-07 06:52 +0000
1import asyncio
2import signal
3import socket
4from abc import ABC, abstractmethod
5from contextlib import suppress
6from typing import Any, List, Optional, Set, Type
8from yarl import URL
10from .abc import AbstractAccessLogger, AbstractStreamWriter
11from .http_parser import RawRequestMessage
12from .streams import StreamReader
13from .typedefs import PathLike
14from .web_app import Application
15from .web_log import AccessLogger
16from .web_protocol import RequestHandler
17from .web_request import Request
18from .web_server import Server
# ``ssl`` is optional: Python can be built without OpenSSL.  In that case a
# plain ``object`` stands in so the ``Optional[SSLContext]`` annotations below
# still resolve at import time.
try:
    from ssl import SSLContext
except ImportError:  # no OpenSSL support in this interpreter
    SSLContext = object  # type: ignore[misc,assignment]
# Public API of this module: site classes, then runners, then the exception
# raised by the optional signal handlers.
__all__ = (
    # sites
    "BaseSite",
    "TCPSite",
    "UnixSite",
    "NamedPipeSite",
    "SockSite",
    # runners
    "BaseRunner",
    "AppRunner",
    "ServerRunner",
    # exceptions
    "GracefulExit",
)
class GracefulExit(SystemExit):
    """``SystemExit`` subclass raised by the runner's SIGINT/SIGTERM handlers.

    The class-level ``code`` overrides ``SystemExit.code``, so an uncaught
    instance makes the interpreter exit with status 1.
    """

    code = 1
def _raise_graceful_exit() -> None:
    """Callback installed for SIGINT/SIGTERM; always raises :class:`GracefulExit`."""
    raise GracefulExit
class BaseSite(ABC):
    """Abstract listening endpoint ("site") attached to a runner.

    A site can only be created once ``runner.setup()`` has produced a
    low-level server.  Subclasses implement :attr:`name` and :meth:`start`;
    ``start`` must call ``super().start()`` (which registers the site with
    the runner) and then store the created server object in
    ``self._server``.
    """

    __slots__ = ("_runner", "_shutdown_timeout", "_ssl_context", "_backlog", "_server")

    def __init__(
        self,
        runner: "BaseRunner",
        *,
        shutdown_timeout: float = 60.0,
        ssl_context: Optional[SSLContext] = None,
        backlog: int = 128,
    ) -> None:
        """Bind the site to *runner*.

        Args:
            runner: runner whose ``setup()`` has already been awaited.
            shutdown_timeout: seconds allowed for each waiting phase of
                :meth:`stop`.
            ssl_context: optional TLS context for the listener.
            backlog: listen backlog passed on by subclasses.

        Raises:
            RuntimeError: if ``runner.server`` is still ``None``.
        """
        if runner.server is None:
            raise RuntimeError("Call runner.setup() before making a site")
        self._runner = runner
        self._shutdown_timeout = shutdown_timeout
        self._ssl_context = ssl_context
        self._backlog = backlog
        self._server: Optional[asyncio.AbstractServer] = None

    @property
    @abstractmethod
    def name(self) -> str:
        """Human-readable address of this site."""
        pass  # pragma: no cover

    @abstractmethod
    async def start(self) -> None:
        """Register this site with its runner; subclasses then start listening."""
        self._runner._reg_site(self)

    async def stop(self) -> None:
        """Stop listening, drain pending work and unregister from the runner.

        The steps run in this order:
        1. Close the listening server.
        2. Wait up to ``shutdown_timeout`` for tasks created after startup.
        3. Run the runner's ``shutdown()``.
        4. Shut down open connections.
        5. Wait for the listener to report it is fully closed.

        Raises:
            RuntimeError: (from the runner) if this site is not registered.
        """
        self._runner._check_site(self)
        if self._server is None:
            self._runner._unreg_site(self)
            return  # not started yet
        # Stop accepting new connections.
        self._server.close()

        # Wait for pending tasks for a given time limit.
        with suppress(asyncio.TimeoutError):
            await asyncio.wait_for(
                self._wait(asyncio.current_task()), timeout=self._shutdown_timeout
            )

        await self._runner.shutdown()
        assert self._runner.server
        await self._runner.server.shutdown(self._shutdown_timeout)

        # Only wait for the listener to finish closing once the connections
        # above have been shut down.  Since Python 3.12.1
        # ``Server.wait_closed()`` also waits for every accepted connection to
        # go away, so awaiting it earlier could block forever on idle
        # keep-alive connections.  The wait is bounded as a safety net.
        # Named pipes do not have a wait_closed() method.
        if hasattr(self._server, "wait_closed"):
            with suppress(asyncio.TimeoutError):
                await asyncio.wait_for(
                    self._server.wait_closed(), timeout=self._shutdown_timeout
                )
        self._runner._unreg_site(self)

    async def _wait(self, parent_task: Optional["asyncio.Task[object]"]) -> None:
        """Wait until every task not recorded at runner startup has finished.

        Excluded from the wait:
        - tasks in ``runner.starting_tasks``;
        - the task running this coroutine;
        - *parent_task*.

        ``asyncio.all_tasks()`` is re-sampled after each round, so tasks
        spawned while waiting are waited for too.
        """
        exclude = self._runner.starting_tasks | {asyncio.current_task(), parent_task}
        # TODO(PY38): while tasks := asyncio.all_tasks() - exclude:
        tasks = asyncio.all_tasks() - exclude
        while tasks:
            await asyncio.wait(tasks)
            tasks = asyncio.all_tasks() - exclude
class TCPSite(BaseSite):
    """Site listening on a TCP host/port."""

    __slots__ = ("_host", "_port", "_reuse_address", "_reuse_port")

    def __init__(
        self,
        runner: "BaseRunner",
        host: Optional[str] = None,
        port: Optional[int] = None,
        *,
        shutdown_timeout: float = 60.0,
        ssl_context: Optional[SSLContext] = None,
        backlog: int = 128,
        reuse_address: Optional[bool] = None,
        reuse_port: Optional[bool] = None,
    ) -> None:
        """Create a TCP site.

        Args:
            runner: runner whose ``setup()`` has already been awaited.
            host: interface to bind; ``None`` is passed straight to
                ``loop.create_server``.
            port: port to bind; defaults to 8443 when *ssl_context* is given,
                otherwise 8080.
            shutdown_timeout, ssl_context, backlog: see :class:`BaseSite`.
            reuse_address, reuse_port: forwarded to ``loop.create_server``.
        """
        super().__init__(
            runner,
            shutdown_timeout=shutdown_timeout,
            ssl_context=ssl_context,
            backlog=backlog,
        )
        self._host = host
        if port is None:
            port = 8443 if self._ssl_context else 8080
        self._port = port
        self._reuse_address = reuse_address
        self._reuse_port = reuse_port

    @property
    def name(self) -> str:
        """URL of the site; a missing host is displayed as ``0.0.0.0``."""
        scheme = "https" if self._ssl_context else "http"
        host = "0.0.0.0" if self._host is None else self._host
        return str(URL.build(scheme=scheme, host=host, port=self._port))

    async def start(self) -> None:
        """Register the site and start listening with ``loop.create_server``."""
        await super().start()
        # Inside a coroutine the running loop is the one to use;
        # get_running_loop() states that explicitly (as _make_request does).
        loop = asyncio.get_running_loop()
        server = self._runner.server
        assert server is not None
        self._server = await loop.create_server(
            server,
            self._host,
            self._port,
            ssl=self._ssl_context,
            backlog=self._backlog,
            reuse_address=self._reuse_address,
            reuse_port=self._reuse_port,
        )
class UnixSite(BaseSite):
    """Site listening on a Unix domain socket path."""

    __slots__ = ("_path",)

    def __init__(
        self,
        runner: "BaseRunner",
        path: PathLike,
        *,
        shutdown_timeout: float = 60.0,
        ssl_context: Optional[SSLContext] = None,
        backlog: int = 128,
    ) -> None:
        """Create a Unix-socket site bound to *path*.

        ``shutdown_timeout``, ``ssl_context`` and ``backlog`` are documented
        on :class:`BaseSite`.
        """
        super().__init__(
            runner,
            shutdown_timeout=shutdown_timeout,
            ssl_context=ssl_context,
            backlog=backlog,
        )
        self._path = path

    @property
    def name(self) -> str:
        """Display name of the form ``http(s)://unix:<path>:``."""
        scheme = "https" if self._ssl_context else "http"
        return f"{scheme}://unix:{self._path}:"

    async def start(self) -> None:
        """Register the site and listen with ``loop.create_unix_server``."""
        await super().start()
        # Called from a coroutine, so use the explicitly running loop.
        loop = asyncio.get_running_loop()
        server = self._runner.server
        assert server is not None
        self._server = await loop.create_unix_server(
            server,
            self._path,
            ssl=self._ssl_context,
            backlog=self._backlog,
        )
class NamedPipeSite(BaseSite):
    """Site serving on a Windows named pipe (proactor event loop only)."""

    __slots__ = ("_path",)

    def __init__(
        self, runner: "BaseRunner", path: str, *, shutdown_timeout: float = 60.0
    ) -> None:
        """Create a named-pipe site for *path*.

        Raises:
            RuntimeError: if the current event loop is not a
                ``ProactorEventLoop``, including on non-Windows platforms
                where that class does not exist.
        """
        loop = asyncio.get_event_loop()
        # ProactorEventLoop is only exported by asyncio on Windows; look it up
        # defensively so other platforms get the intended RuntimeError rather
        # than an AttributeError from the attribute access itself.
        proactor_loop_cls = getattr(asyncio, "ProactorEventLoop", None)
        if proactor_loop_cls is None or not isinstance(loop, proactor_loop_cls):
            raise RuntimeError(
                "Named Pipes only available in proactor loop under windows"
            )
        super().__init__(runner, shutdown_timeout=shutdown_timeout)
        self._path = path

    @property
    def name(self) -> str:
        """The pipe path itself."""
        return self._path

    async def start(self) -> None:
        """Register the site and serve the pipe via ``loop.start_serving_pipe``."""
        await super().start()
        loop = asyncio.get_running_loop()
        server = self._runner.server
        assert server is not None
        _server = await loop.start_serving_pipe(  # type: ignore[attr-defined]
            server, self._path
        )
        # start_serving_pipe returns a sequence; the first element is kept
        # as this site's server object.
        self._server = _server[0]
class SockSite(BaseSite):
    """Site serving on an already-created, caller-supplied socket."""

    __slots__ = ("_sock", "_name")

    def __init__(
        self,
        runner: "BaseRunner",
        sock: socket.socket,
        *,
        shutdown_timeout: float = 60.0,
        ssl_context: Optional[SSLContext] = None,
        backlog: int = 128,
    ) -> None:
        """Wrap *sock* and compute the display name once, from ``getsockname()``.

        ``shutdown_timeout``, ``ssl_context`` and ``backlog`` are documented
        on :class:`BaseSite`.
        """
        super().__init__(
            runner,
            shutdown_timeout=shutdown_timeout,
            ssl_context=ssl_context,
            backlog=backlog,
        )
        self._sock = sock
        scheme = "https" if self._ssl_context else "http"
        # AF_UNIX is not defined on every platform, hence the hasattr guard.
        if hasattr(socket, "AF_UNIX") and sock.family == socket.AF_UNIX:
            name = f"{scheme}://unix:{sock.getsockname()}:"
        else:
            # The first two fields of getsockname() are host and port for
            # the address families handled here.
            host, port = sock.getsockname()[:2]
            name = str(URL.build(scheme=scheme, host=host, port=port))
        self._name = name

    @property
    def name(self) -> str:
        """Display name computed in ``__init__``."""
        return self._name

    async def start(self) -> None:
        """Register the site and serve the socket with ``loop.create_server``."""
        await super().start()
        # Called from a coroutine, so use the explicitly running loop.
        loop = asyncio.get_running_loop()
        server = self._runner.server
        assert server is not None
        self._server = await loop.create_server(
            server, sock=self._sock, ssl=self._ssl_context, backlog=self._backlog
        )
class BaseRunner(ABC):
    """Owns a low-level :class:`Server` and the sites that expose it.

    Lifecycle:
    1. ``await setup()`` creates the server and, optionally, installs
       SIGINT/SIGTERM handlers.
    2. Sites are created against the runner and started.
    3. ``await cleanup()`` stops every site, tears down the server and
       removes the signal handlers.
    """

    __slots__ = ("starting_tasks", "_handle_signals", "_kwargs", "_server", "_sites")

    def __init__(self, *, handle_signals: bool = False, **kwargs: Any) -> None:
        """Store configuration; no server exists until :meth:`setup` runs.

        Args:
            handle_signals: if true, ``setup()`` installs handlers that raise
                :class:`GracefulExit` on SIGINT/SIGTERM.
            **kwargs: stored for subclasses (``AppRunner`` forwards them to
                :class:`Server`).
        """
        self._handle_signals = handle_signals
        self._kwargs = kwargs
        self._server: Optional[Server] = None
        self._sites: List[BaseSite] = []
        # Snapshot of tasks alive right after startup, refreshed by setup().
        # Initialised here so the public slot is readable before setup()
        # instead of raising AttributeError.
        self.starting_tasks: Set["asyncio.Task[Any]"] = set()

    @property
    def server(self) -> Optional[Server]:
        """The low-level server, or ``None`` before setup / after cleanup."""
        return self._server

    @property
    def addresses(self) -> List[Any]:
        """``getsockname()`` of every listening socket of every started site."""
        ret: List[Any] = []
        for site in self._sites:
            server = site._server
            if server is not None:
                sockets = server.sockets  # type: ignore[attr-defined]
                if sockets is not None:
                    for sock in sockets:
                        ret.append(sock.getsockname())
        return ret

    @property
    def sites(self) -> Set[BaseSite]:
        """A fresh set containing the currently registered sites."""
        return set(self._sites)

    async def setup(self) -> None:
        """Install optional signal handlers and create the low-level server."""
        loop = asyncio.get_running_loop()

        if self._handle_signals:
            try:
                loop.add_signal_handler(signal.SIGINT, _raise_graceful_exit)
                loop.add_signal_handler(signal.SIGTERM, _raise_graceful_exit)
            except NotImplementedError:  # pragma: no cover
                # add_signal_handler is not implemented on Windows
                pass

        self._server = await self._make_server()
        # On shutdown we want to avoid waiting on tasks which run forever.
        # It's very likely that all tasks which run forever will have been created by
        # the time we have completed the application startup (in self._make_server()),
        # so we just record all running tasks here and exclude them later.
        self.starting_tasks = asyncio.all_tasks()

    @abstractmethod
    async def shutdown(self) -> None:
        """Hook run by ``BaseSite.stop()`` before connections are shut down."""
        pass  # pragma: no cover

    async def cleanup(self) -> None:
        """Stop all sites, release the server and remove signal handlers."""
        loop = asyncio.get_running_loop()

        # The loop over sites is intentional, an exception on gather()
        # leaves self._sites in unpredictable state.
        # The loop guarantees that a site is either deleted on success or
        # still present on failure
        for site in list(self._sites):
            await site.stop()
        await self._cleanup_server()
        self._server = None
        if self._handle_signals:
            try:
                loop.remove_signal_handler(signal.SIGINT)
                loop.remove_signal_handler(signal.SIGTERM)
            except NotImplementedError:  # pragma: no cover
                # remove_signal_handler is not implemented on Windows
                pass

    @abstractmethod
    async def _make_server(self) -> Server:
        """Create (or return) the low-level server; called by :meth:`setup`."""
        pass  # pragma: no cover

    @abstractmethod
    async def _cleanup_server(self) -> None:
        """Release server resources; called by :meth:`cleanup`."""
        pass  # pragma: no cover

    def _reg_site(self, site: BaseSite) -> None:
        """Register *site*; raises RuntimeError if it is already registered."""
        if site in self._sites:
            raise RuntimeError(f"Site {site} is already registered in runner {self}")
        self._sites.append(site)

    def _check_site(self, site: BaseSite) -> None:
        """Raise RuntimeError unless *site* is registered with this runner."""
        if site not in self._sites:
            raise RuntimeError(f"Site {site} is not registered in runner {self}")

    def _unreg_site(self, site: BaseSite) -> None:
        """Unregister *site*; raises RuntimeError if it is not registered."""
        if site not in self._sites:
            raise RuntimeError(f"Site {site} is not registered in runner {self}")
        self._sites.remove(site)
class ServerRunner(BaseRunner):
    """Runner for a ready-made low-level :class:`Server`.

    There is no application lifecycle to drive here.  ``shutdown`` and
    ``_cleanup_server`` do nothing, and ``_make_server`` hands back the
    server given to the constructor.
    """

    __slots__ = ("_web_server",)

    def __init__(
        self, web_server: Server, *, handle_signals: bool = False, **kwargs: Any
    ) -> None:
        super().__init__(handle_signals=handle_signals, **kwargs)
        self._web_server = web_server

    async def shutdown(self) -> None:
        """No-op: a bare server has no application shutdown hooks."""

    async def _make_server(self) -> Server:
        """Return the server supplied at construction time."""
        return self._web_server

    async def _cleanup_server(self) -> None:
        """No-op: ``_make_server`` did not allocate anything."""
class AppRunner(BaseRunner):
    """Runner that drives an :class:`Application` through its lifecycle.

    - ``setup()`` runs the application's startup hooks and freezes it.
    - ``shutdown()`` runs the application's shutdown hooks.
    - ``cleanup()`` runs the application's cleanup hooks.

    Constructor keyword arguments, overridden by the application's own
    handler arguments, are forwarded to the low-level :class:`Server`.
    """

    __slots__ = ("_app",)

    def __init__(
        self,
        app: Application,
        *,
        handle_signals: bool = False,
        access_log_class: Type[AbstractAccessLogger] = AccessLogger,
        **kwargs: Any,
    ) -> None:
        """Validate *app* and assemble the keyword arguments for the server.

        Raises:
            TypeError: if *app* is not an :class:`Application`, or the
                effective ``access_log_class`` is not a subclass of
                :class:`AbstractAccessLogger`.
        """
        if not isinstance(app, Application):
            raise TypeError(
                "The first argument should be web.Application "
                f"instance, got {app!r}"
            )
        kwargs["access_log_class"] = access_log_class

        # Handler arguments configured on the application take precedence
        # over the keyword arguments passed here, access_log_class included.
        handler_args = app._handler_args
        if handler_args:
            kwargs.update(handler_args.items())

        log_cls = kwargs["access_log_class"]
        if not issubclass(log_cls, AbstractAccessLogger):
            raise TypeError(
                "access_log_class must be subclass of "
                f"aiohttp.abc.AbstractAccessLogger, got {log_cls}"
            )

        super().__init__(handle_signals=handle_signals, **kwargs)
        self._app = app

    @property
    def app(self) -> Application:
        """The wrapped application."""
        return self._app

    async def shutdown(self) -> None:
        """Run the application's shutdown hooks."""
        await self._app.shutdown()

    async def _make_server(self) -> Server:
        """Start the application, freeze it and wrap its handler in a Server."""
        application = self._app
        application.on_startup.freeze()
        await application.startup()
        application.freeze()

        return Server(
            application._handle,  # type: ignore[arg-type]
            request_factory=self._make_request,
            **self._kwargs,
        )

    def _make_request(
        self,
        message: RawRequestMessage,
        payload: StreamReader,
        protocol: RequestHandler,
        writer: AbstractStreamWriter,
        task: "asyncio.Task[None]",
        _cls: Type[Request] = Request,
    ) -> Request:
        """Request factory handed to :class:`Server`; builds a ``_cls`` instance."""
        running_loop = asyncio.get_running_loop()
        max_body = self.app._client_max_size
        return _cls(
            message,
            payload,
            protocol,
            writer,
            task,
            running_loop,
            client_max_size=max_body,
        )

    async def _cleanup_server(self) -> None:
        """Run the application's cleanup hooks."""
        await self._app.cleanup()