Coverage for /pythoncovmergedfiles/medio/medio/src/aiohttp/aiohttp/web_runner.py: 39%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1import asyncio
2import signal
3import socket
4from abc import ABC, abstractmethod
5from typing import Any, Generic, TypeVar
7from yarl import URL
9from .abc import AbstractAccessLogger, AbstractStreamWriter
10from .http_parser import RawRequestMessage
11from .streams import StreamReader
12from .typedefs import PathLike
13from .web_app import Application
14from .web_log import AccessLogger
15from .web_protocol import RequestHandler
16from .web_request import BaseRequest, Request
17from .web_server import Server
try:
    from ssl import SSLContext
except ImportError:  # pragma: no cover
    # Interpreter built without the ssl module: substitute ``object`` so the
    # ``SSLContext | None`` parameter annotations in this module still evaluate.
    SSLContext = object  # type: ignore[misc,assignment]

# Public names exported by this module.
__all__ = (
    "BaseSite",
    "TCPSite",
    "UnixSite",
    "NamedPipeSite",
    "SockSite",
    "BaseRunner",
    "AppRunner",
    "ServerRunner",
    "GracefulExit",
)

# Request type produced by a runner's Server. It is bound to BaseRequest:
# ServerRunner specialises it as BaseRequest and AppRunner as Request.
_Request = TypeVar("_Request", bound=BaseRequest)
class GracefulExit(SystemExit):
    """``SystemExit`` subclass raised when SIGINT/SIGTERM is received.

    ``BaseRunner.setup()`` installs :func:`_raise_graceful_exit` for both
    signals when ``handle_signals`` is enabled. The exit code is fixed at 1.
    """

    code = 1


def _raise_graceful_exit() -> None:
    """Signal-handler callback: raise :class:`GracefulExit`."""
    raise GracefulExit
47class BaseSite(ABC):
48 __slots__ = ("_runner", "_ssl_context", "_backlog", "_server")
50 def __init__(
51 self,
52 runner: "BaseRunner[Any]",
53 *,
54 ssl_context: SSLContext | None = None,
55 backlog: int = 128,
56 ) -> None:
57 if runner.server is None:
58 raise RuntimeError("Call runner.setup() before making a site")
59 self._runner = runner
60 self._ssl_context = ssl_context
61 self._backlog = backlog
62 self._server: asyncio.Server | None = None
64 @property
65 @abstractmethod
66 def name(self) -> str:
67 """Return the name of the site (e.g. a URL)."""
69 @abstractmethod
70 async def start(self) -> None:
71 self._runner._reg_site(self)
73 async def stop(self) -> None:
74 self._runner._check_site(self)
75 if self._server is not None: # Maybe not started yet
76 self._server.close()
78 self._runner._unreg_site(self)
class TCPSite(BaseSite):
    """Site listening on a TCP host/port (HTTP, or HTTPS with ``ssl_context``)."""

    __slots__ = ("_host", "_port", "_bound_port", "_reuse_address", "_reuse_port")

    def __init__(
        self,
        runner: "BaseRunner[Any]",
        host: str | None = None,
        port: int | None = None,
        *,
        ssl_context: SSLContext | None = None,
        backlog: int = 128,
        reuse_address: bool | None = None,
        reuse_port: bool | None = None,
    ) -> None:
        super().__init__(
            runner,
            ssl_context=ssl_context,
            backlog=backlog,
        )
        self._host = host
        if port is None:
            # Default port depends on whether TLS is enabled.
            port = 8443 if self._ssl_context else 8080
        self._port = port
        # Real port once start() has bound the socket (matters for port=0).
        self._bound_port: int | None = None
        self._reuse_address = reuse_address
        self._reuse_port = reuse_port

    @property
    def port(self) -> int:
        """The port the server is listening on.

        If the server hasn't been started yet, this returns the requested port
        (which might be 0 for a dynamic port).
        After the server starts, it returns the actual bound port. This is
        especially useful when port=0 was requested, as it allows retrieving the
        dynamically assigned port after the site has started.
        """
        if self._bound_port is not None:
            return self._bound_port
        return self._port

    @property
    def name(self) -> str:
        """URL of the site; ``0.0.0.0`` is reported when no host was given."""
        scheme = "https" if self._ssl_context else "http"
        host = "0.0.0.0" if not self._host else self._host
        return str(URL.build(scheme=scheme, host=host, port=self.port))

    async def start(self) -> None:
        """Register with the runner, start listening and record the bound port."""
        await super().start()
        # This is a coroutine, so a loop is always running: use it directly
        # instead of going through get_event_loop()'s policy logic.
        loop = asyncio.get_running_loop()
        server = self._runner.server
        assert server is not None
        self._server = await loop.create_server(
            server,
            self._host,
            self._port,
            ssl=self._ssl_context,
            backlog=self._backlog,
            reuse_address=self._reuse_address,
            reuse_port=self._reuse_port,
        )
        if self._server.sockets:
            # Only the first listening socket's port is recorded.
            self._bound_port = self._server.sockets[0].getsockname()[1]
        else:
            self._bound_port = self._port
class UnixSite(BaseSite):
    """Site listening on a Unix domain socket at ``path``."""

    __slots__ = ("_path",)

    def __init__(
        self,
        runner: "BaseRunner[Any]",
        path: PathLike,
        *,
        ssl_context: SSLContext | None = None,
        backlog: int = 128,
    ) -> None:
        super().__init__(
            runner,
            ssl_context=ssl_context,
            backlog=backlog,
        )
        self._path = path

    @property
    def name(self) -> str:
        """Pseudo-URL ``http://unix:<path>:`` (``https`` when TLS is enabled)."""
        scheme = "https" if self._ssl_context else "http"
        return f"{scheme}://unix:{self._path}:"

    async def start(self) -> None:
        """Register with the runner and start serving on the Unix socket."""
        await super().start()
        # Inside a coroutine the running loop is the one to use.
        loop = asyncio.get_running_loop()
        server = self._runner.server
        assert server is not None
        self._server = await loop.create_unix_server(
            server,
            self._path,
            ssl=self._ssl_context,
            backlog=self._backlog,
        )
class NamedPipeSite(BaseSite):
    """Site serving on a Windows named pipe; requires a proactor event loop."""

    __slots__ = ("_path",)

    def __init__(self, runner: "BaseRunner[Any]", path: str) -> None:
        loop = asyncio.get_event_loop()
        # asyncio only defines ProactorEventLoop on Windows. Look it up
        # defensively so other platforms get the intended RuntimeError
        # instead of an AttributeError from the attribute access.
        proactor_cls = getattr(asyncio, "ProactorEventLoop", None)
        if proactor_cls is None or not isinstance(loop, proactor_cls):
            raise RuntimeError(
                "Named Pipes only available in proactor loop under windows"
            )
        super().__init__(runner)
        self._path = path

    @property
    def name(self) -> str:
        """The pipe path this site serves on."""
        return self._path

    async def start(self) -> None:
        """Register with the runner and start serving on the named pipe."""
        await super().start()
        loop = asyncio.get_running_loop()
        server = self._runner.server
        assert server is not None
        _server = await loop.start_serving_pipe(  # type: ignore[attr-defined]
            server, self._path
        )
        # start_serving_pipe() returns a sequence; keep its first element as
        # this site's server object.
        self._server = _server[0]
class SockSite(BaseSite):
    """Site serving on a socket object created and bound by the caller."""

    __slots__ = ("_sock", "_name")

    def __init__(
        self,
        runner: "BaseRunner[Any]",
        sock: socket.socket,
        *,
        ssl_context: SSLContext | None = None,
        backlog: int = 128,
    ) -> None:
        super().__init__(
            runner,
            ssl_context=ssl_context,
            backlog=backlog,
        )
        self._sock = sock
        scheme = "https" if self._ssl_context else "http"
        # AF_UNIX is not available on every platform, hence the hasattr guard.
        if hasattr(socket, "AF_UNIX") and sock.family == socket.AF_UNIX:
            name = f"{scheme}://unix:{sock.getsockname()}:"
        else:
            # getsockname() can return more than (host, port); keep the first two.
            host, port = sock.getsockname()[:2]
            name = str(URL.build(scheme=scheme, host=host, port=port))
        # The name is computed once, from the socket's address at construction.
        self._name = name

    @property
    def name(self) -> str:
        """Name derived from the socket address in ``__init__``."""
        return self._name

    async def start(self) -> None:
        """Register with the runner and start serving on the given socket."""
        await super().start()
        # Inside a coroutine the running loop is the one to use.
        loop = asyncio.get_running_loop()
        server = self._runner.server
        assert server is not None
        self._server = await loop.create_server(
            server, sock=self._sock, ssl=self._ssl_context, backlog=self._backlog
        )
class BaseRunner(ABC, Generic[_Request]):
    """Shared runner lifecycle: setup(), site bookkeeping and cleanup().

    Subclasses provide the server via :meth:`_make_server`, plus
    :meth:`shutdown` and :meth:`_cleanup_server` hooks.
    """

    __slots__ = ("_handle_signals", "_kwargs", "_server", "_sites", "_shutdown_timeout")

    def __init__(
        self,
        *,
        handle_signals: bool = False,
        shutdown_timeout: float = 60.0,
        **kwargs: Any,
    ) -> None:
        # When True, setup() installs SIGINT/SIGTERM handlers and cleanup()
        # removes them again.
        self._handle_signals = handle_signals
        # Extra keyword arguments kept for subclasses (AppRunner passes them
        # to Server in _make_server()).
        self._kwargs = kwargs
        self._server: Server[_Request] | None = None
        self._sites: list[BaseSite] = []
        # Passed to Server.shutdown() during cleanup().
        self._shutdown_timeout = shutdown_timeout

    @property
    def server(self) -> Server[_Request] | None:
        """The server created by setup(); None before setup and after cleanup."""
        return self._server

    @property
    def addresses(self) -> list[Any]:
        """``getsockname()`` of every listening socket of every started site."""
        ret: list[Any] = []
        for site in self._sites:
            server = site._server
            if server is not None:
                sockets = server.sockets
                if sockets is not None:
                    for sock in sockets:
                        ret.append(sock.getsockname())
        return ret

    @property
    def sites(self) -> set[BaseSite]:
        """A snapshot of the currently registered sites."""
        return set(self._sites)

    async def setup(self) -> None:
        """Optionally install signal handlers, then build the server."""
        # setup() is a coroutine, so use the running loop, the same call
        # cleanup() uses when removing these handlers.
        loop = asyncio.get_running_loop()

        if self._handle_signals:
            try:
                loop.add_signal_handler(signal.SIGINT, _raise_graceful_exit)
                loop.add_signal_handler(signal.SIGTERM, _raise_graceful_exit)
            except NotImplementedError:
                # add_signal_handler is not implemented on Windows
                pass

        self._server = await self._make_server()

    @abstractmethod
    async def shutdown(self) -> None:
        """Call any shutdown hooks to help server close gracefully."""

    async def cleanup(self) -> None:
        """Stop all sites, shut the server down and remove signal handlers."""
        # The loop over sites is intentional, an exception on gather()
        # leaves self._sites in unpredictable state.
        # The loop guarantees that a site is either deleted on success or
        # still present on failure
        for site in list(self._sites):
            await site.stop()

        if self._server:  # If setup succeeded
            # Yield to event loop to ensure incoming requests prior to stopping the sites
            # have all started to be handled before we proceed to close idle connections.
            await asyncio.sleep(0)
            self._server.pre_shutdown()
            await self.shutdown()
            await self._server.shutdown(self._shutdown_timeout)
        await self._cleanup_server()

        self._server = None
        if self._handle_signals:
            loop = asyncio.get_running_loop()
            try:
                loop.remove_signal_handler(signal.SIGINT)
                loop.remove_signal_handler(signal.SIGTERM)
            except NotImplementedError:
                # remove_signal_handler is not implemented on Windows
                pass

    @abstractmethod
    async def _make_server(self) -> Server[_Request]:
        """Return a new server for the runner to serve requests."""

    @abstractmethod
    async def _cleanup_server(self) -> None:
        """Run any cleanup steps after the server is shutdown."""

    def _reg_site(self, site: BaseSite) -> None:
        """Record *site*; raise RuntimeError if it is already registered."""
        if site in self._sites:
            raise RuntimeError(f"Site {site} is already registered in runner {self}")
        self._sites.append(site)

    def _check_site(self, site: BaseSite) -> None:
        """Raise RuntimeError unless *site* is registered with this runner."""
        if site not in self._sites:
            raise RuntimeError(f"Site {site} is not registered in runner {self}")

    def _unreg_site(self, site: BaseSite) -> None:
        """Forget *site*; raise RuntimeError if it is not registered."""
        # Reuse the membership check so both paths raise the same error.
        self._check_site(site)
        self._sites.remove(site)
class ServerRunner(BaseRunner[BaseRequest]):
    """Runner for a caller-supplied low-level :class:`Server`."""

    __slots__ = ("_web_server",)

    def __init__(
        self,
        web_server: Server[BaseRequest],
        *,
        handle_signals: bool = False,
        **kwargs: Any,
    ) -> None:
        super().__init__(handle_signals=handle_signals, **kwargs)
        # Built by the caller; _make_server() simply hands it back.
        self._web_server = web_server

    async def shutdown(self) -> None:
        """Intentionally a no-op: a bare server has no shutdown hooks here."""

    async def _make_server(self) -> Server[BaseRequest]:
        """Return the server passed to the constructor."""
        return self._web_server

    async def _cleanup_server(self) -> None:
        """Intentionally a no-op."""
class AppRunner(BaseRunner[Request]):
    """Runner that serves an :class:`Application`."""

    __slots__ = ("_app",)

    def __init__(
        self,
        app: Application,
        *,
        handle_signals: bool = False,
        access_log_class: type[AbstractAccessLogger] = AccessLogger,
        **kwargs: Any,
    ) -> None:
        if not isinstance(app, Application):
            raise TypeError(
                f"The first argument should be web.Application instance, got {app!r}"
            )
        kwargs["access_log_class"] = access_log_class
        # Handler arguments stored on the application take precedence over
        # the keyword arguments passed here, including access_log_class.
        if app._handler_args:
            kwargs.update(app._handler_args.items())

        log_cls = kwargs["access_log_class"]
        if not issubclass(log_cls, AbstractAccessLogger):
            raise TypeError(
                "access_log_class must be subclass of "
                f"aiohttp.abc.AbstractAccessLogger, got {log_cls}"
            )

        super().__init__(handle_signals=handle_signals, **kwargs)
        self._app = app

    @property
    def app(self) -> Application:
        """The application served by this runner."""
        return self._app

    async def shutdown(self) -> None:
        """Delegate to the application's ``shutdown()``."""
        await self._app.shutdown()

    async def _make_server(self) -> Server[Request]:
        """Run app startup, freeze the app and wrap its handler in a Server."""
        application = self._app
        application.on_startup.freeze()
        await application.startup()
        application.freeze()
        return Server(
            application._handle,
            request_factory=self._make_request,
            **self._kwargs,
        )

    def _make_request(
        self,
        message: RawRequestMessage,
        payload: StreamReader,
        protocol: RequestHandler[Request],
        writer: AbstractStreamWriter,
        task: "asyncio.Task[None]",
        _cls: type[Request] = Request,
    ) -> Request:
        """Request factory handed to Server; passes the app's client_max_size."""
        return _cls(
            message,
            payload,
            protocol,
            writer,
            task,
            asyncio.get_running_loop(),
            client_max_size=self.app._client_max_size,
        )

    async def _cleanup_server(self) -> None:
        """Delegate to the application's ``cleanup()``."""
        await self._app.cleanup()