Coverage for /pythoncovmergedfiles/medio/medio/src/aiohttp/aiohttp/web_runner.py: 40%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1import asyncio
2import signal
3import socket
4from abc import ABC, abstractmethod
5from typing import TYPE_CHECKING, Any, Generic, List, Optional, Set, Type, TypeVar
7from yarl import URL
9from .abc import AbstractAccessLogger, AbstractStreamWriter
10from .http_parser import RawRequestMessage
11from .streams import StreamReader
12from .typedefs import PathLike
13from .web_app import Application
14from .web_log import AccessLogger
15from .web_protocol import RequestHandler
16from .web_request import BaseRequest, Request
17from .web_server import Server
19if TYPE_CHECKING:
20 from ssl import SSLContext
21else:
22 try:
23 from ssl import SSLContext
24 except ImportError: # pragma: no cover
25 SSLContext = object # type: ignore[misc,assignment]
# Public API of this module: the names exported by ``from ... import *``.
__all__ = (
    # Sites: the endpoints a runner listens on.
    "BaseSite",
    "TCPSite",
    "UnixSite",
    "NamedPipeSite",
    "SockSite",
    # Runners: hold the low-level server and the list of registered sites.
    "BaseRunner",
    "AppRunner",
    "ServerRunner",
    # Raised by the runner's signal handlers to request shutdown.
    "GracefulExit",
)

# Request type handled by a runner's server; bounded by BaseRequest so every
# concrete runner deals in some BaseRequest subclass.
_Request = TypeVar("_Request", bound=BaseRequest)
class GracefulExit(SystemExit):
    """Exit request raised from the runner's SIGINT/SIGTERM handler.

    As a :class:`SystemExit` subclass it unwinds the stack like ``sys.exit``;
    the class-level ``code`` overrides the exit status to 1.
    """

    code = 1


def _raise_graceful_exit() -> None:
    """Signal-handler callback that aborts the current run with GracefulExit."""
    raise GracefulExit
class BaseSite(ABC):
    """Abstract listening endpoint attached to a runner.

    Subclasses implement :attr:`name` and :meth:`start`; their ``start`` is
    expected to ``await super().start()`` so the site registers itself with
    the runner before it begins listening.
    """

    __slots__ = ("_runner", "_ssl_context", "_backlog", "_server")

    def __init__(
        self,
        runner: "BaseRunner[Any]",
        *,
        ssl_context: Optional[SSLContext] = None,
        backlog: int = 128,
    ) -> None:
        """Bind the site to *runner*.

        :raises RuntimeError: if the runner has no server yet
            (``runner.setup()`` was not called).
        """
        if runner.server is None:
            raise RuntimeError("Call runner.setup() before making a site")
        self._runner, self._ssl_context, self._backlog = runner, ssl_context, backlog
        # Set by the concrete start(); None means the site is not listening.
        self._server: Optional[asyncio.Server] = None

    @property
    @abstractmethod
    def name(self) -> str:
        """Return the name of the site (e.g. a URL)."""

    @abstractmethod
    async def start(self) -> None:
        """Register this site with its runner; subclasses then start serving."""
        self._runner._reg_site(self)

    async def stop(self) -> None:
        """Close the listening server (if started) and unregister the site.

        The runner's ``_check_site`` is consulted first and may raise if the
        site is not registered with it.
        """
        owner = self._runner
        owner._check_site(self)
        listener = self._server
        if listener is not None:  # start() may never have been called
            listener.close()
        owner._unreg_site(self)
class TCPSite(BaseSite):
    """Site listening on a TCP host/port via ``loop.create_server``."""

    __slots__ = ("_host", "_port", "_reuse_address", "_reuse_port")

    def __init__(
        self,
        runner: "BaseRunner[Any]",
        host: Optional[str] = None,
        port: Optional[int] = None,
        *,
        ssl_context: Optional[SSLContext] = None,
        backlog: int = 128,
        reuse_address: Optional[bool] = None,
        reuse_port: Optional[bool] = None,
    ) -> None:
        """Configure the endpoint.

        :param host: interface to bind; None/empty is passed through to
            ``create_server`` unchanged.
        :param port: TCP port; defaults to 8443 with an SSL context, else 8080.
        :param reuse_address: forwarded to ``loop.create_server``.
        :param reuse_port: forwarded to ``loop.create_server``.
        """
        super().__init__(
            runner,
            ssl_context=ssl_context,
            backlog=backlog,
        )
        self._host = host
        if port is None:
            port = 8443 if self._ssl_context else 8080
        self._port = port
        self._reuse_address = reuse_address
        self._reuse_port = reuse_port

    @property
    def name(self) -> str:
        """URL of the site, e.g. ``http://0.0.0.0:8080``."""
        scheme = "https" if self._ssl_context else "http"
        # A missing host is reported as 0.0.0.0 (create_server treats a
        # None/empty host as "all interfaces").
        host = "0.0.0.0" if not self._host else self._host
        return str(URL.build(scheme=scheme, host=host, port=self._port))

    async def start(self) -> None:
        """Register with the runner and start listening on the TCP endpoint."""
        await super().start()
        # Within a coroutine get_running_loop() is the preferred API; it is
        # also what the rest of this module uses (cleanup, _make_request).
        loop = asyncio.get_running_loop()
        server = self._runner.server
        assert server is not None
        self._server = await loop.create_server(
            server,
            self._host,
            self._port,
            ssl=self._ssl_context,
            backlog=self._backlog,
            reuse_address=self._reuse_address,
            reuse_port=self._reuse_port,
        )
class UnixSite(BaseSite):
    """Site listening on a Unix domain socket path via ``create_unix_server``."""

    __slots__ = ("_path",)

    def __init__(
        self,
        runner: "BaseRunner[Any]",
        path: PathLike,
        *,
        ssl_context: Optional[SSLContext] = None,
        backlog: int = 128,
    ) -> None:
        """Configure the site to serve on the socket file at *path*."""
        super().__init__(
            runner,
            ssl_context=ssl_context,
            backlog=backlog,
        )
        self._path = path

    @property
    def name(self) -> str:
        """Pseudo-URL of the form ``http://unix:<path>:``."""
        scheme = "https" if self._ssl_context else "http"
        return f"{scheme}://unix:{self._path}:"

    async def start(self) -> None:
        """Register with the runner and start listening on the socket path."""
        await super().start()
        # Preferred inside coroutines; consistent with the rest of the module.
        loop = asyncio.get_running_loop()
        server = self._runner.server
        assert server is not None
        self._server = await loop.create_unix_server(
            server,
            self._path,
            ssl=self._ssl_context,
            backlog=self._backlog,
        )
class NamedPipeSite(BaseSite):
    """Site serving on a Windows named pipe (requires a proactor event loop)."""

    __slots__ = ("_path",)

    def __init__(self, runner: "BaseRunner[Any]", path: str) -> None:
        """Configure the site for the pipe at *path*.

        :raises RuntimeError: if the current event loop is not a
            ``ProactorEventLoop`` (including every non-Windows platform).
        """
        loop = asyncio.get_event_loop()
        # asyncio only defines ProactorEventLoop on Windows; a plain attribute
        # access would raise AttributeError elsewhere instead of the intended
        # RuntimeError below.
        proactor_loop_cls = getattr(asyncio, "ProactorEventLoop", None)
        if proactor_loop_cls is None or not isinstance(loop, proactor_loop_cls):
            raise RuntimeError(
                "Named Pipes only available in proactor loop under windows"
            )
        super().__init__(runner)
        self._path = path

    @property
    def name(self) -> str:
        """The pipe path itself."""
        return self._path

    async def start(self) -> None:
        """Register with the runner and start serving on the named pipe."""
        await super().start()
        loop = asyncio.get_running_loop()
        server = self._runner.server
        assert server is not None
        # start_serving_pipe returns a list of servers; the first is kept so
        # stop() can close it.
        _server = await loop.start_serving_pipe(  # type: ignore[attr-defined]
            server, self._path
        )
        self._server = _server[0]
class SockSite(BaseSite):
    """Site serving on an already-created socket object."""

    __slots__ = ("_sock", "_name")

    def __init__(
        self,
        runner: "BaseRunner[Any]",
        sock: socket.socket,
        *,
        ssl_context: Optional[SSLContext] = None,
        backlog: int = 128,
    ) -> None:
        """Wrap *sock*; the site name is derived once from ``getsockname()``."""
        super().__init__(
            runner,
            ssl_context=ssl_context,
            backlog=backlog,
        )
        self._sock = sock
        scheme = "https" if self._ssl_context else "http"
        # AF_UNIX is not defined on every platform, hence the hasattr guard.
        if hasattr(socket, "AF_UNIX") and sock.family == socket.AF_UNIX:
            name = f"{scheme}://unix:{sock.getsockname()}:"
        else:
            # For IP sockets only (host, port) is used; IPv6 adds extra fields.
            host, port = sock.getsockname()[:2]
            name = str(URL.build(scheme=scheme, host=host, port=port))
        self._name = name

    @property
    def name(self) -> str:
        """Name computed at construction time from the socket address."""
        return self._name

    async def start(self) -> None:
        """Register with the runner and start serving on the wrapped socket."""
        await super().start()
        # Preferred inside coroutines; consistent with the rest of the module.
        loop = asyncio.get_running_loop()
        server = self._runner.server
        assert server is not None
        self._server = await loop.create_server(
            server, sock=self._sock, ssl=self._ssl_context, backlog=self._backlog
        )
class BaseRunner(ABC, Generic[_Request]):
    """Base runner: owns the low-level server and the sites listening on it.

    Lifecycle: ``setup()`` builds the server through ``_make_server()``;
    sites are then created and started; ``cleanup()`` stops every site, runs
    the graceful shutdown and finally ``_cleanup_server()``.
    """

    __slots__ = ("_handle_signals", "_kwargs", "_server", "_sites", "_shutdown_timeout")

    def __init__(
        self,
        *,
        handle_signals: bool = False,
        shutdown_timeout: float = 60.0,
        **kwargs: Any,
    ) -> None:
        """Store the runner configuration.

        :param handle_signals: if true, ``setup()`` installs SIGINT/SIGTERM
            handlers that raise :class:`GracefulExit`.
        :param shutdown_timeout: value handed to the server's ``shutdown()``.
        :param kwargs: extra options kept for subclasses (``AppRunner``
            forwards them to ``Server``).
        """
        self._handle_signals = handle_signals
        self._kwargs = kwargs
        self._server: Optional[Server[_Request]] = None
        self._sites: List[BaseSite] = []
        self._shutdown_timeout = shutdown_timeout

    @property
    def server(self) -> Optional[Server[_Request]]:
        """Server built by ``setup()``; None before setup and after cleanup."""
        return self._server

    @property
    def addresses(self) -> List[Any]:  # type: ignore[misc]
        """``getsockname()`` of every socket of every started site."""
        ret: List[Any] = []
        for site in self._sites:
            server = site._server
            if server is not None:
                sockets = server.sockets
                if sockets is not None:
                    for sock in sockets:
                        ret.append(sock.getsockname())
        return ret

    @property
    def sites(self) -> Set[BaseSite]:
        """A snapshot set of the currently registered sites."""
        return set(self._sites)

    async def setup(self) -> None:
        """Install signal handlers (if requested) and build the server."""
        # Consistent with cleanup(): inside a coroutine use the running loop.
        loop = asyncio.get_running_loop()

        if self._handle_signals:
            try:
                loop.add_signal_handler(signal.SIGINT, _raise_graceful_exit)
                loop.add_signal_handler(signal.SIGTERM, _raise_graceful_exit)
            except NotImplementedError:
                # add_signal_handler is not implemented on Windows
                pass

        self._server = await self._make_server()

    @abstractmethod
    async def shutdown(self) -> None:
        """Call any shutdown hooks to help server close gracefully."""

    async def cleanup(self) -> None:
        """Stop all sites, shut the server down and undo ``setup()``."""
        # The loop over sites is intentional, an exception on gather()
        # leaves self._sites in unpredictable state.
        # The loop guarantees that a site is either deleted on success or
        # still present on failure
        for site in list(self._sites):
            await site.stop()

        if self._server:  # If setup succeeded
            # Yield to event loop to ensure incoming requests prior to stopping the sites
            # have all started to be handled before we proceed to close idle connections.
            await asyncio.sleep(0)
            self._server.pre_shutdown()
            await self.shutdown()
            await self._server.shutdown(self._shutdown_timeout)
        await self._cleanup_server()

        self._server = None
        if self._handle_signals:
            loop = asyncio.get_running_loop()
            try:
                loop.remove_signal_handler(signal.SIGINT)
                loop.remove_signal_handler(signal.SIGTERM)
            except NotImplementedError:
                # remove_signal_handler is not implemented on Windows
                pass

    @abstractmethod
    async def _make_server(self) -> Server[_Request]:
        """Return a new server for the runner to serve requests."""

    @abstractmethod
    async def _cleanup_server(self) -> None:
        """Run any cleanup steps after the server is shutdown."""

    def _reg_site(self, site: BaseSite) -> None:
        """Record *site*; raises RuntimeError if it is already registered."""
        if site in self._sites:
            raise RuntimeError(f"Site {site} is already registered in runner {self}")
        self._sites.append(site)

    def _check_site(self, site: BaseSite) -> None:
        """Raise RuntimeError unless *site* is registered with this runner."""
        if site not in self._sites:
            raise RuntimeError(f"Site {site} is not registered in runner {self}")

    def _unreg_site(self, site: BaseSite) -> None:
        """Forget *site*; raises RuntimeError if it is not registered."""
        # Reuse the single membership check instead of duplicating it.
        self._check_site(site)
        self._sites.remove(site)
class ServerRunner(BaseRunner[BaseRequest]):
    """Runner for a caller-supplied low-level :class:`Server`.

    Unlike AppRunner, no application hooks run: the given server is simply
    returned by ``_make_server()``, and shutdown/cleanup are no-ops here.
    """

    __slots__ = ("_web_server",)

    def __init__(
        self,
        web_server: Server[BaseRequest],
        *,
        handle_signals: bool = False,
        **kwargs: Any,
    ) -> None:
        super().__init__(handle_signals=handle_signals, **kwargs)
        # Handed back unchanged by _make_server() during setup().
        self._web_server = web_server

    async def shutdown(self) -> None:
        """No shutdown hooks for a bare server."""
        return None

    async def _make_server(self) -> Server[BaseRequest]:
        """Return the server passed to the constructor."""
        return self._web_server

    async def _cleanup_server(self) -> None:
        """Nothing extra to release for the supplied server."""
        return None
class AppRunner(BaseRunner[Request]):
    """Web Application runner.

    Runs the application's startup hooks when the server is built, its
    shutdown hooks during graceful shutdown and its cleanup afterwards.
    """

    __slots__ = ("_app",)

    def __init__(
        self,
        app: Application,
        *,
        handle_signals: bool = False,
        access_log_class: Type[AbstractAccessLogger] = AccessLogger,
        **kwargs: Any,
    ) -> None:
        """Create a runner for *app*.

        :param access_log_class: access logger class; an
            ``access_log_class`` in the application's handler args wins.
        :param kwargs: extra options, forwarded to ``Server`` together with
            the application's handler args (which take precedence).
        :raises TypeError: if *app* is not an :class:`Application`, or the
            effective access log class is not an AbstractAccessLogger subclass.
        """
        if not isinstance(app, Application):
            raise TypeError(
                "The first argument should be web.Application "
                "instance, got {!r}".format(app)
            )
        kwargs["access_log_class"] = access_log_class

        # Handler args configured on the application override the arguments
        # given here, including access_log_class.
        if app._handler_args:
            kwargs.update(app._handler_args)

        log_class = kwargs["access_log_class"]
        # Check it is a class first: issubclass() on a non-class raises its own,
        # less helpful TypeError instead of the message below.
        if not (
            isinstance(log_class, type) and issubclass(log_class, AbstractAccessLogger)
        ):
            raise TypeError(
                "access_log_class must be subclass of "
                "aiohttp.abc.AbstractAccessLogger, got {}".format(log_class)
            )

        super().__init__(handle_signals=handle_signals, **kwargs)
        self._app = app

    @property
    def app(self) -> Application:
        """The application being run."""
        return self._app

    async def shutdown(self) -> None:
        """Run the application's shutdown hooks."""
        await self._app.shutdown()

    async def _make_server(self) -> Server[Request]:
        """Run startup hooks, freeze the app and build its Server."""
        # on_startup is frozen before startup() runs it, and the whole app is
        # frozen afterwards.
        self._app.on_startup.freeze()
        await self._app.startup()
        self._app.freeze()

        return Server(
            self._app._handle,
            request_factory=self._make_request,
            **self._kwargs,
        )

    def _make_request(
        self,
        message: RawRequestMessage,
        payload: StreamReader,
        protocol: RequestHandler[Request],
        writer: AbstractStreamWriter,
        task: "asyncio.Task[None]",
        _cls: Type[Request] = Request,
    ) -> Request:
        """Request factory for the Server, honouring the app's client_max_size."""
        loop = asyncio.get_running_loop()
        return _cls(
            message,
            payload,
            protocol,
            writer,
            task,
            loop,
            client_max_size=self.app._client_max_size,
        )

    async def _cleanup_server(self) -> None:
        """Run the application's cleanup hooks."""
        await self._app.cleanup()