Coverage for /pythoncovmergedfiles/medio/medio/src/aiohttp/aiohttp/web_runner.py: 40%
223 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-26 06:16 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-26 06:16 +0000
1import asyncio
2import signal
3import socket
4from abc import ABC, abstractmethod
5from typing import Any, Awaitable, Callable, List, Optional, Set, Type
7from yarl import URL
9from .abc import AbstractAccessLogger, AbstractStreamWriter
10from .http_parser import RawRequestMessage
11from .streams import StreamReader
12from .typedefs import PathLike
13from .web_app import Application
14from .web_log import AccessLogger
15from .web_protocol import RequestHandler
16from .web_request import Request
17from .web_server import Server
19try:
20 from ssl import SSLContext
21except ImportError:
22 SSLContext = object # type: ignore[misc,assignment]
# Public API of this module: the site classes that expose a listening
# endpoint, the runner classes that own the server, and the exception raised
# by the runner's optional signal handlers.
__all__ = (
    "BaseSite", "TCPSite", "UnixSite", "NamedPipeSite", "SockSite",
    "BaseRunner", "AppRunner", "ServerRunner",
    "GracefulExit",
)
class GracefulExit(SystemExit):
    """``SystemExit`` raised by the runner's SIGINT/SIGTERM handler.

    The class-level ``code`` overrides ``SystemExit.code`` so that, if the
    exception is not caught, the interpreter exits with status 1.
    """

    code = 1
def _raise_graceful_exit() -> None:
    """Signal callback installed by ``BaseRunner.setup``; raises GracefulExit."""
    # Raising the class is equivalent to raising a no-argument instance.
    raise GracefulExit
class BaseSite(ABC):
    """Abstract endpoint that exposes a runner's server.

    A site may only be built after ``runner.setup()`` has produced a server.
    Subclasses provide :attr:`name` and :meth:`start`; their ``start`` first
    calls ``super().start()`` (which registers the site with the runner) and
    then stores the created listener in ``self._server``.
    """

    __slots__ = ("_runner", "_ssl_context", "_backlog", "_server")

    def __init__(
        self,
        runner: "BaseRunner",
        *,
        ssl_context: Optional[SSLContext] = None,
        backlog: int = 128,
    ) -> None:
        # The runner's server only exists once setup() has been awaited.
        if runner.server is None:
            raise RuntimeError("Call runner.setup() before making a site")
        self._runner = runner
        self._ssl_context = ssl_context
        self._backlog = backlog
        # Assigned by subclasses in start(); None while not listening.
        self._server: Optional[asyncio.AbstractServer] = None

    @property
    @abstractmethod
    def name(self) -> str:
        """Human-readable address of this site."""

    @abstractmethod
    async def start(self) -> None:
        """Register this site with its runner; subclasses then bind."""
        self._runner._reg_site(self)

    async def stop(self) -> None:
        """Close the listener (if any) and unregister from the runner.

        Raises:
            RuntimeError: if the site is not registered with its runner.
        """
        owner = self._runner
        owner._check_site(self)
        listener = self._server
        if listener is not None:  # the site may never have been started
            listener.close()
        owner._unreg_site(self)
class TCPSite(BaseSite):
    """Site that listens on a TCP host/port.

    When ``port`` is omitted it defaults to 8443 if an SSL context is given
    and 8080 otherwise. ``host`` (possibly ``None``) and the reuse flags are
    passed through unchanged to ``loop.create_server``.
    """

    __slots__ = ("_host", "_port", "_reuse_address", "_reuse_port")

    def __init__(
        self,
        runner: "BaseRunner",
        host: Optional[str] = None,
        port: Optional[int] = None,
        *,
        ssl_context: Optional[SSLContext] = None,
        backlog: int = 128,
        reuse_address: Optional[bool] = None,
        reuse_port: Optional[bool] = None,
    ) -> None:
        super().__init__(
            runner,
            ssl_context=ssl_context,
            backlog=backlog,
        )
        self._host = host
        if port is None:
            # Default port depends on whether TLS is enabled.
            port = 8443 if self._ssl_context else 8080
        self._port = port
        self._reuse_address = reuse_address
        self._reuse_port = reuse_port

    @property
    def name(self) -> str:
        """URL of the site; an unspecified host is rendered as 0.0.0.0."""
        scheme = "https" if self._ssl_context else "http"
        host = "0.0.0.0" if self._host is None else self._host
        return str(URL.build(scheme=scheme, host=host, port=self._port))

    async def start(self) -> None:
        """Register with the runner, then start listening on host/port."""
        await super().start()
        # We are inside a coroutine, so the running loop is the one to bind
        # to; get_running_loop() is the documented choice here (and matches
        # BaseRunner.cleanup()).
        loop = asyncio.get_running_loop()
        server = self._runner.server
        assert server is not None
        self._server = await loop.create_server(
            server,
            self._host,
            self._port,
            ssl=self._ssl_context,
            backlog=self._backlog,
            reuse_address=self._reuse_address,
            reuse_port=self._reuse_port,
        )
class UnixSite(BaseSite):
    """Site that listens on a Unix domain socket at ``path``."""

    __slots__ = ("_path",)

    def __init__(
        self,
        runner: "BaseRunner",
        path: PathLike,
        *,
        ssl_context: Optional[SSLContext] = None,
        backlog: int = 128,
    ) -> None:
        super().__init__(
            runner,
            ssl_context=ssl_context,
            backlog=backlog,
        )
        self._path = path

    @property
    def name(self) -> str:
        """Pseudo-URL of the form ``<scheme>://unix:<path>:``."""
        scheme = "https" if self._ssl_context else "http"
        return f"{scheme}://unix:{self._path}:"

    async def start(self) -> None:
        """Register with the runner, then start listening on the socket path."""
        await super().start()
        # Inside a coroutine: use the running loop explicitly.
        loop = asyncio.get_running_loop()
        server = self._runner.server
        assert server is not None
        self._server = await loop.create_unix_server(
            server,
            self._path,
            ssl=self._ssl_context,
            backlog=self._backlog,
        )
class NamedPipeSite(BaseSite):
    """Site serving a named pipe; requires a proactor event loop.

    Raises:
        RuntimeError: at construction time if the current event loop is not
            an ``asyncio.ProactorEventLoop`` (including on platforms where
            that class does not exist).
    """

    __slots__ = ("_path",)

    def __init__(self, runner: "BaseRunner", path: str) -> None:
        loop = asyncio.get_event_loop()
        # ProactorEventLoop is only defined by asyncio on Windows. Look it up
        # defensively so other platforms get the intended RuntimeError rather
        # than an AttributeError from the attribute access itself.
        proactor_cls = getattr(asyncio, "ProactorEventLoop", None)
        if proactor_cls is None or not isinstance(loop, proactor_cls):
            raise RuntimeError(
                "Named Pipes only available in proactor loop under windows"
            )
        super().__init__(runner)
        self._path = path

    @property
    def name(self) -> str:
        """The pipe path itself."""
        return self._path

    async def start(self) -> None:
        """Register with the runner, then start serving the pipe."""
        await super().start()
        loop = asyncio.get_running_loop()
        server = self._runner.server
        assert server is not None
        pipe_servers = await loop.start_serving_pipe(  # type: ignore[attr-defined]
            server, self._path
        )
        # Only the first element of the result is kept as this site's server.
        self._server = pipe_servers[0]
class SockSite(BaseSite):
    """Site that serves on an already-created socket.

    The display name is computed once, at construction, from the socket's
    family and ``getsockname()``.
    """

    __slots__ = ("_sock", "_name")

    def __init__(
        self,
        runner: "BaseRunner",
        sock: socket.socket,
        *,
        ssl_context: Optional[SSLContext] = None,
        backlog: int = 128,
    ) -> None:
        super().__init__(
            runner,
            ssl_context=ssl_context,
            backlog=backlog,
        )
        self._sock = sock
        scheme = "https" if self._ssl_context else "http"
        # AF_UNIX is not defined on every platform, hence the hasattr guard.
        if hasattr(socket, "AF_UNIX") and sock.family == socket.AF_UNIX:
            name = f"{scheme}://unix:{sock.getsockname()}:"
        else:
            # IPv6 sockets return a 4-tuple; only host and port are needed.
            host, port = sock.getsockname()[:2]
            name = str(URL.build(scheme=scheme, host=host, port=port))
        self._name = name

    @property
    def name(self) -> str:
        """Name computed from the socket at construction time."""
        return self._name

    async def start(self) -> None:
        """Register with the runner, then start serving on the socket."""
        await super().start()
        # Inside a coroutine: use the running loop explicitly.
        loop = asyncio.get_running_loop()
        server = self._runner.server
        assert server is not None
        self._server = await loop.create_server(
            server, sock=self._sock, ssl=self._ssl_context, backlog=self._backlog
        )
class BaseRunner(ABC):
    """Base class for runners that own a low-level :class:`Server`.

    ``setup()`` creates the server (optionally installing SIGINT/SIGTERM
    handlers); sites are then created against the runner and register
    themselves when started; ``cleanup()`` stops every site and tears the
    server down again.
    """

    __slots__ = (
        "shutdown_callback",
        "_handle_signals",
        "_kwargs",
        "_server",
        "_sites",
        "_shutdown_timeout",
    )

    def __init__(
        self,
        *,
        handle_signals: bool = False,
        shutdown_timeout: float = 60.0,
        **kwargs: Any,
    ) -> None:
        # Optional coroutine function awaited by cleanup() after shutdown()
        # and before the server itself is shut down.
        self.shutdown_callback: Optional[Callable[[], Awaitable[None]]] = None
        self._handle_signals = handle_signals
        # Extra keyword arguments kept for subclasses.
        self._kwargs = kwargs
        self._server: Optional[Server] = None
        # Sites registered through _reg_site(), in registration order.
        self._sites: List[BaseSite] = []
        # Forwarded to Server.shutdown() during cleanup().
        self._shutdown_timeout = shutdown_timeout

    @property
    def server(self) -> Optional[Server]:
        """Server created by setup(); None before setup and after cleanup."""
        return self._server

    @property
    def addresses(self) -> List[Any]:
        """``getsockname()`` of every socket of every started site."""
        ret: List[Any] = []
        for site in self._sites:
            server = site._server
            if server is not None:
                sockets = server.sockets  # type: ignore[attr-defined]
                if sockets is not None:
                    for sock in sockets:
                        ret.append(sock.getsockname())
        return ret

    @property
    def sites(self) -> Set[BaseSite]:
        """Snapshot (a new set) of the registered sites."""
        return set(self._sites)

    async def setup(self) -> None:
        """Install optional signal handlers and create the server."""
        # setup() is a coroutine, so the running loop is the right one; this
        # also matches the get_running_loop() call in cleanup().
        loop = asyncio.get_running_loop()

        if self._handle_signals:
            try:
                loop.add_signal_handler(signal.SIGINT, _raise_graceful_exit)
                loop.add_signal_handler(signal.SIGTERM, _raise_graceful_exit)
            except NotImplementedError:  # pragma: no cover
                # add_signal_handler is not implemented on Windows
                pass

        self._server = await self._make_server()

    @abstractmethod
    async def shutdown(self) -> None:
        """Call any shutdown hooks to help server close gracefully."""

    async def cleanup(self) -> None:
        """Stop all sites, shut the server down and remove signal handlers."""
        # The loop over sites is intentional, an exception on gather()
        # leaves self._sites in unpredictable state.
        # The loop guarantees that a site is either deleted on success or
        # still present on failure
        for site in list(self._sites):
            await site.stop()

        if self._server:  # If setup succeeded
            self._server.pre_shutdown()
            await self.shutdown()

            if self.shutdown_callback:
                await self.shutdown_callback()

            await self._server.shutdown(self._shutdown_timeout)
        await self._cleanup_server()

        self._server = None
        if self._handle_signals:
            loop = asyncio.get_running_loop()
            try:
                loop.remove_signal_handler(signal.SIGINT)
                loop.remove_signal_handler(signal.SIGTERM)
            except NotImplementedError:  # pragma: no cover
                # remove_signal_handler is not implemented on Windows
                pass

    @abstractmethod
    async def _make_server(self) -> Server:
        pass  # pragma: no cover

    @abstractmethod
    async def _cleanup_server(self) -> None:
        pass  # pragma: no cover

    def _reg_site(self, site: BaseSite) -> None:
        """Register *site*; raises RuntimeError if it is already registered."""
        if site in self._sites:
            raise RuntimeError(f"Site {site} is already registered in runner {self}")
        self._sites.append(site)

    def _check_site(self, site: BaseSite) -> None:
        """Raise RuntimeError unless *site* is registered with this runner."""
        if site not in self._sites:
            raise RuntimeError(f"Site {site} is not registered in runner {self}")

    def _unreg_site(self, site: BaseSite) -> None:
        """Unregister *site*; raises RuntimeError if it is not registered."""
        # Reuse the single source of truth for the "not registered" check.
        self._check_site(site)
        self._sites.remove(site)
class ServerRunner(BaseRunner):
    """Runner wrapping an already constructed low-level :class:`Server`.

    ``_make_server`` hands the given server back unchanged, and there are no
    shutdown or cleanup hooks of its own to run.
    """

    __slots__ = ("_web_server",)

    def __init__(
        self, web_server: Server, *, handle_signals: bool = False, **kwargs: Any
    ) -> None:
        super().__init__(handle_signals=handle_signals, **kwargs)
        # Returned as-is from _make_server() during setup().
        self._web_server = web_server

    async def shutdown(self) -> None:
        """No-op: a bare server has no shutdown hooks."""

    async def _make_server(self) -> Server:
        """Return the server supplied to the constructor."""
        prebuilt = self._web_server
        return prebuilt

    async def _cleanup_server(self) -> None:
        """No-op: nothing extra to release."""
class AppRunner(BaseRunner):
    """Runner that serves a :class:`Application`.

    Keyword arguments, together with the application's own handler args
    (which take precedence), are forwarded to :class:`Server` when the server
    is created in ``_make_server``.
    """

    __slots__ = ("_app",)

    def __init__(
        self,
        app: Application,
        *,
        handle_signals: bool = False,
        access_log_class: Type[AbstractAccessLogger] = AccessLogger,
        **kwargs: Any,
    ) -> None:
        if not isinstance(app, Application):
            raise TypeError(
                "The first argument should be web.Application "
                "instance, got {!r}".format(app)
            )
        kwargs["access_log_class"] = access_log_class

        # Handler args configured on the application override the runner's
        # keyword arguments, including access_log_class.
        if app._handler_args:
            kwargs.update(app._handler_args)

        log_cls = kwargs["access_log_class"]
        # issubclass() raises its own, less helpful TypeError when given a
        # non-class (e.g. a logger instance); reject those with the same
        # descriptive message as a wrong class.
        if not (isinstance(log_cls, type) and issubclass(log_cls, AbstractAccessLogger)):
            raise TypeError(
                "access_log_class must be subclass of "
                "aiohttp.abc.AbstractAccessLogger, got {}".format(log_cls)
            )

        super().__init__(handle_signals=handle_signals, **kwargs)
        self._app = app

    @property
    def app(self) -> Application:
        """The application served by this runner."""
        return self._app

    async def shutdown(self) -> None:
        """Run the application's shutdown hooks."""
        await self._app.shutdown()

    async def _make_server(self) -> Server:
        """Start the application, freeze it and build its Server."""
        # on_startup is frozen before startup() runs it; the whole app is
        # frozen only after startup completes.
        self._app.on_startup.freeze()
        await self._app.startup()
        self._app.freeze()

        return Server(
            self._app._handle,  # type: ignore[arg-type]
            request_factory=self._make_request,
            **self._kwargs,
        )

    def _make_request(
        self,
        message: RawRequestMessage,
        payload: StreamReader,
        protocol: RequestHandler,
        writer: AbstractStreamWriter,
        task: "asyncio.Task[None]",
        _cls: Type[Request] = Request,
    ) -> Request:
        """Request factory passed to Server; applies the app's size limit."""
        loop = asyncio.get_running_loop()
        return _cls(
            message,
            payload,
            protocol,
            writer,
            task,
            loop,
            client_max_size=self.app._client_max_size,
        )

    async def _cleanup_server(self) -> None:
        """Run the application's cleanup hooks."""
        await self._app.cleanup()