1from __future__ import annotations
2
3import contextlib
4import functools
5import inspect
6import re
7import traceback
8import types
9import warnings
10from collections.abc import Awaitable, Collection, Generator, Sequence
11from contextlib import AbstractAsyncContextManager, AbstractContextManager, asynccontextmanager
12from enum import Enum
13from re import Pattern
14from typing import Any, Callable, TypeVar
15
16from starlette._exception_handler import wrap_app_handling_exceptions
17from starlette._utils import get_route_path, is_async_callable
18from starlette.concurrency import run_in_threadpool
19from starlette.convertors import CONVERTOR_TYPES, Convertor
20from starlette.datastructures import URL, Headers, URLPath
21from starlette.exceptions import HTTPException
22from starlette.middleware import Middleware
23from starlette.requests import Request
24from starlette.responses import PlainTextResponse, RedirectResponse, Response
25from starlette.types import ASGIApp, Lifespan, Receive, Scope, Send
26from starlette.websockets import WebSocket, WebSocketClose
27
28
class NoMatchFound(Exception):
    """
    Signals that reversing a URL failed.

    Raised by `.url_for(name, **path_params)` and `.url_path_for(name, **path_params)`
    when no route matches the requested name and parameters.
    """

    def __init__(self, name: str, path_params: dict[str, Any]) -> None:
        # Only the parameter names are reported; their values are left out of the message.
        params = ", ".join(path_params)
        message = f'No route exists for name "{name}" and params "{params}".'
        super().__init__(message)
38
39
class Match(Enum):
    """Result of testing a route against an incoming scope (returned by the `matches()` methods)."""

    # The route does not apply to the scope.
    NONE = 0
    # Returned by `Route.matches()` when the path matches but the request
    # method is not one of the route's methods.
    PARTIAL = 1
    # The route applies to the scope.
    FULL = 2
44
45
def iscoroutinefunction_or_partial(obj: Any) -> bool:  # pragma: no cover
    """
    Correctly determines if an object is a coroutine function,
    including those wrapped in functools.partial objects.

    This helper is deprecated: every call emits a `DeprecationWarning`.

    Returns `True` when `obj`, after unwrapping any `functools.partial`
    layers, is a coroutine function, and `False` otherwise.
    """
    warnings.warn(
        "iscoroutinefunction_or_partial is deprecated, and will be removed in a future release.",
        DeprecationWarning,
        # Attribute the warning to the code calling this helper rather than to
        # this module, so that it shows up at the call site that needs updating.
        stacklevel=2,
    )
    # Peel off (possibly nested) partial wrappers to reach the underlying callable.
    while isinstance(obj, functools.partial):
        obj = obj.func
    return inspect.iscoroutinefunction(obj)
58
59
def request_response(
    func: Callable[[Request], Awaitable[Response] | Response],
) -> ASGIApp:
    """
    Build an ASGI application out of `func(request) -> response`.

    `func` may be a coroutine function or a regular function; a regular
    function is invoked through `run_in_threadpool`.
    """
    if is_async_callable(func):
        handler: Callable[[Request], Awaitable[Response]] = func  # type: ignore
    else:
        handler = functools.partial(run_in_threadpool, func)  # type: ignore

    async def app(scope: Scope, receive: Receive, send: Send) -> None:
        request = Request(scope, receive, send)

        # The endpoint call is wrapped in its own ASGI callable so that it can
        # be passed through `wrap_app_handling_exceptions`.
        async def call_endpoint(scope: Scope, receive: Receive, send: Send) -> None:
            response = await handler(request)
            await response(scope, receive, send)

        guarded = wrap_app_handling_exceptions(call_endpoint, request)
        await guarded(scope, receive, send)

    return app
81
82
def websocket_session(
    func: Callable[[WebSocket], Awaitable[None]],
) -> ASGIApp:
    """
    Build an ASGI application out of the coroutine `func(session)`.
    """
    # NOTE: `func` is not checked to be a coroutine function here; its result
    # is simply awaited when the application is called.

    async def app(scope: Scope, receive: Receive, send: Send) -> None:
        session = WebSocket(scope, receive=receive, send=send)

        # The endpoint call is wrapped in its own ASGI callable so that it can
        # be passed through `wrap_app_handling_exceptions`.
        async def run_session(scope: Scope, receive: Receive, send: Send) -> None:
            await func(session)

        guarded = wrap_app_handling_exceptions(run_session, session)
        await guarded(scope, receive, send)

    return app
100
101
def get_name(endpoint: Callable[..., Any]) -> str:
    """
    Return the endpoint's `__name__`, or the name of its class when the
    endpoint has no `__name__` attribute.
    """
    class_name = endpoint.__class__.__name__
    return getattr(endpoint, "__name__", class_name)
104
105
def replace_params(
    path: str,
    param_convertors: dict[str, Convertor[Any]],
    path_params: dict[str, str],
) -> tuple[str, dict[str, str]]:
    """
    Substitute `{name}` placeholders in `path` with stringified parameter values.

    Every entry of `path_params` whose placeholder occurs in `path` is rendered
    with its convertor's `to_string()` and removed from `path_params`; the dict
    is modified in place. Returns the resulting path together with that same
    dict, which then only holds the parameters that were not used.
    """
    # Walk over a snapshot of the names, since entries are removed along the way.
    for name in tuple(path_params):
        placeholder = "{" + name + "}"
        if placeholder not in path:
            continue
        raw_value = path_params[name]
        rendered = param_convertors[name].to_string(raw_value)
        path = path.replace(placeholder, rendered)
        del path_params[name]
    return path, path_params
118
119
# Match parameters in URL paths, eg. '{param}', and '{param:int}'.
# Group 1 captures the parameter name; the optional group 2 captures the
# convertor suffix together with its leading colon (eg. ':int').
PARAM_REGEX = re.compile("{([a-zA-Z_][a-zA-Z0-9_]*)(:[a-zA-Z_][a-zA-Z0-9_]*)?}")
122
123
def compile_path(
    path: str,
) -> tuple[Pattern[str], str, dict[str, Convertor[Any]]]:
    """
    Translate a path string such as "/{username:str}", or a host string such
    as "{subdomain}.mydomain.org", into a three-tuple of
    (regex, format, {param_name:convertor}).

    regex:      "/(?P<username>[^/]+)"
    format:     "/{username}"
    convertors: {"username": StringConvertor()}

    Raises `ValueError` when a parameter name is used more than once.
    """
    # Anything that does not start with "/" is treated as a host pattern.
    is_host = not path.startswith("/")

    regex_parts = ["^"]
    format_parts = []
    param_convertors = {}
    repeated = set()
    cursor = 0

    for found in PARAM_REGEX.finditer(path):
        # A parameter without an explicit convertor falls back to "str".
        param_name, convertor_type = found.groups("str")
        convertor_type = convertor_type.lstrip(":")
        assert convertor_type in CONVERTOR_TYPES, f"Unknown path convertor '{convertor_type}'"
        convertor = CONVERTOR_TYPES[convertor_type]

        # Literal text between the previous parameter and this one.
        literal = path[cursor : found.start()]

        regex_parts.append(re.escape(literal))
        regex_parts.append(f"(?P<{param_name}>{convertor.regex})")

        format_parts.append(literal)
        format_parts.append("{%s}" % param_name)

        if param_name in param_convertors:
            repeated.add(param_name)
        param_convertors[param_name] = convertor

        cursor = found.end()

    if repeated:
        names = ", ".join(sorted(repeated))
        ending = "s" if len(repeated) > 1 else ""
        raise ValueError(f"Duplicated param name{ending} {names} at path {path}")

    tail = path[cursor:]
    # Align with `Host.matches()` behavior, which ignores port: for host
    # patterns only the part before the first ":" goes into the regex.
    regex_tail = tail.split(":")[0] if is_host else tail
    regex_parts.append(re.escape(regex_tail) + "$")
    format_parts.append(tail)

    return re.compile("".join(regex_parts)), "".join(format_parts), param_convertors
178
179
class BaseRoute:
    """
    Interface shared by all routes.

    Subclasses supply `matches()`, `url_path_for()` and `handle()`; this base
    class contributes the stand-alone ASGI entry point built on top of them.
    """

    def matches(self, scope: Scope) -> tuple[Match, Scope]:
        """Return the kind of match for `scope` along with the scope entries to apply."""
        raise NotImplementedError()  # pragma: no cover

    def url_path_for(self, name: str, /, **path_params: Any) -> URLPath:
        """Return the URL path for `name` and `path_params`."""
        raise NotImplementedError()  # pragma: no cover

    async def handle(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Serve a connection whose scope has been matched."""
        raise NotImplementedError()  # pragma: no cover

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """
        Allow a single route to act as a stand-alone ASGI app.

        This is a somewhat contrived case, since routes are almost always used
        within a Router, but it could be useful for some tooling and minimal apps.
        """
        match, child_scope = self.matches(scope)
        if match != Match.NONE:
            scope.update(child_scope)
            await self.handle(scope, receive, send)
            return

        # No match: answer 404 over HTTP, or close the websocket.
        scope_type = scope["type"]
        if scope_type == "http":
            not_found = PlainTextResponse("Not Found", status_code=404)
            await not_found(scope, receive, send)
        elif scope_type == "websocket":  # pragma: no branch
            closer = WebSocketClose()
            await closer(scope, receive, send)
208
209
class Route(BaseRoute):
    """
    HTTP route binding a path pattern, and optionally a set of methods, to an endpoint.
    """

    def __init__(
        self,
        path: str,
        endpoint: Callable[..., Any],
        *,
        methods: Collection[str] | None = None,
        name: str | None = None,
        include_in_schema: bool = True,
        middleware: Sequence[Middleware] | None = None,
    ) -> None:
        assert path.startswith("/"), "Routed paths must start with '/'"
        self.path = path
        self.endpoint = endpoint
        self.name = name if name is not None else get_name(endpoint)
        self.include_in_schema = include_in_schema

        # Look through any functools.partial layers to see what kind of
        # callable the endpoint really is.
        unwrapped = endpoint
        while isinstance(unwrapped, functools.partial):
            unwrapped = unwrapped.func

        if inspect.isfunction(unwrapped) or inspect.ismethod(unwrapped):
            # Functions and methods are called as `func(request) -> response`
            # and answer GET unless told otherwise.
            self.app = request_response(endpoint)
            if methods is None:
                methods = ["GET"]
        else:
            # Anything else (eg. a class) is used as an ASGI application.
            self.app = endpoint

        if middleware is not None:
            # Wrap back to front, so the first middleware listed ends up outermost.
            wrapped = self.app
            for cls, args, kwargs in reversed(middleware):
                wrapped = cls(wrapped, *args, **kwargs)
            self.app = wrapped

        allowed = None
        if methods is not None:
            allowed = {method.upper() for method in methods}
            # Routes that accept GET accept HEAD as well.
            if "GET" in allowed:
                allowed.add("HEAD")
        self.methods = allowed

        self.path_regex, self.path_format, self.param_convertors = compile_path(path)

    def matches(self, scope: Scope) -> tuple[Match, Scope]:
        """
        Match an "http" scope against the route's path pattern.

        Returns `Match.PARTIAL` when the path fits but the request method is
        not one of the route's methods.
        """
        if scope["type"] != "http":
            return Match.NONE, {}

        found = self.path_regex.match(get_route_path(scope))
        if not found:
            return Match.NONE, {}

        converted = {key: self.param_convertors[key].convert(value) for key, value in found.groupdict().items()}
        # Parameters already present in the scope are kept; this route's own win on clashes.
        path_params: dict[str, Any] = dict(scope.get("path_params", {}))
        path_params.update(converted)
        child_scope = {"endpoint": self.endpoint, "path_params": path_params}

        if self.methods and scope["method"] not in self.methods:
            return Match.PARTIAL, child_scope
        return Match.FULL, child_scope

    def url_path_for(self, name: str, /, **path_params: Any) -> URLPath:
        """
        Build the URL path of this route.

        Raises `NoMatchFound` unless `name` is this route's name and the given
        parameter names are exactly the ones the path pattern declares.
        """
        same_name = name == self.name
        if not (same_name and set(path_params) == set(self.param_convertors)):
            raise NoMatchFound(name, path_params)

        path, remaining_params = replace_params(self.path_format, self.param_convertors, path_params)
        assert not remaining_params
        return URLPath(path=path, protocol="http")

    async def handle(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Run the endpoint, or answer 405 when the request method is not allowed."""
        if not self.methods or scope["method"] in self.methods:
            await self.app(scope, receive, send)
            return

        headers = {"Allow": ", ".join(self.methods)}
        if "app" in scope:
            raise HTTPException(status_code=405, headers=headers)
        response = PlainTextResponse("Method Not Allowed", status_code=405, headers=headers)
        await response(scope, receive, send)

    def __eq__(self, other: Any) -> bool:
        if not isinstance(other, Route):
            return False
        return self.path == other.path and self.endpoint == other.endpoint and self.methods == other.methods

    def __repr__(self) -> str:
        path, name = self.path, self.name
        methods = sorted(self.methods or [])
        class_name = self.__class__.__name__
        return f"{class_name}(path={path!r}, name={name!r}, methods={methods!r})"
305
306
class WebSocketRoute(BaseRoute):
    """
    WebSocket route binding a path pattern to an endpoint.
    """

    def __init__(
        self,
        path: str,
        endpoint: Callable[..., Any],
        *,
        name: str | None = None,
        middleware: Sequence[Middleware] | None = None,
    ) -> None:
        assert path.startswith("/"), "Routed paths must start with '/'"
        self.path = path
        self.endpoint = endpoint
        self.name = name if name is not None else get_name(endpoint)

        # Look through any functools.partial layers to see what kind of
        # callable the endpoint really is.
        unwrapped = endpoint
        while isinstance(unwrapped, functools.partial):
            unwrapped = unwrapped.func

        if inspect.isfunction(unwrapped) or inspect.ismethod(unwrapped):
            # Functions and methods are called as `func(websocket)`.
            self.app = websocket_session(endpoint)
        else:
            # Anything else (eg. a class) is used as an ASGI application.
            self.app = endpoint

        if middleware is not None:
            # Wrap back to front, so the first middleware listed ends up outermost.
            wrapped = self.app
            for cls, args, kwargs in reversed(middleware):
                wrapped = cls(wrapped, *args, **kwargs)
            self.app = wrapped

        self.path_regex, self.path_format, self.param_convertors = compile_path(path)

    def matches(self, scope: Scope) -> tuple[Match, Scope]:
        """Match a "websocket" scope against the route's path pattern."""
        if scope["type"] != "websocket":
            return Match.NONE, {}

        found = self.path_regex.match(get_route_path(scope))
        if not found:
            return Match.NONE, {}

        converted = {key: self.param_convertors[key].convert(value) for key, value in found.groupdict().items()}
        # Parameters already present in the scope are kept; this route's own win on clashes.
        path_params: dict[str, Any] = dict(scope.get("path_params", {}))
        path_params.update(converted)
        return Match.FULL, {"endpoint": self.endpoint, "path_params": path_params}

    def url_path_for(self, name: str, /, **path_params: Any) -> URLPath:
        """
        Build the URL path of this route.

        Raises `NoMatchFound` unless `name` is this route's name and the given
        parameter names are exactly the ones the path pattern declares.
        """
        same_name = name == self.name
        if not (same_name and set(path_params) == set(self.param_convertors)):
            raise NoMatchFound(name, path_params)

        path, remaining_params = replace_params(self.path_format, self.param_convertors, path_params)
        assert not remaining_params
        return URLPath(path=path, protocol="websocket")

    async def handle(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Hand the connection over to the wrapped application."""
        await self.app(scope, receive, send)

    def __eq__(self, other: Any) -> bool:
        if not isinstance(other, WebSocketRoute):
            return False
        return self.path == other.path and self.endpoint == other.endpoint

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(path={self.path!r}, name={self.name!r})"
371
372
class Mount(BaseRoute):
    """
    Route that delegates every path below a prefix to another ASGI application,
    or to a `Router` built from `routes`.
    """

    def __init__(
        self,
        path: str,
        app: ASGIApp | None = None,
        routes: Sequence[BaseRoute] | None = None,
        name: str | None = None,
        *,
        middleware: Sequence[Middleware] | None = None,
    ) -> None:
        """
        `path` is the mount prefix (trailing slashes are stripped). Either `app`
        or `routes` must be given; when `app` is omitted a `Router` is created
        from `routes`. `middleware`, if given, is wrapped around the mounted app.
        """
        assert path == "" or path.startswith("/"), "Routed paths must start with '/'"
        assert app is not None or routes is not None, "Either 'app=...', or 'routes=' must be specified"
        self.path = path.rstrip("/")
        if app is not None:
            self._base_app: ASGIApp = app
        else:
            self._base_app = Router(routes=routes)
        # `self.app` is what gets called; `self._base_app` stays unwrapped so
        # that its routes remain reachable through the `routes` property.
        self.app = self._base_app
        if middleware is not None:
            # Wrap back to front, so the first middleware listed ends up outermost.
            for cls, args, kwargs in reversed(middleware):
                self.app = cls(self.app, *args, **kwargs)
        self.name = name
        # Everything after the prefix is captured in a trailing "path" parameter.
        self.path_regex, self.path_format, self.param_convertors = compile_path(self.path + "/{path:path}")

    @property
    def routes(self) -> list[BaseRoute]:
        """Routes of the mounted (unwrapped) app, or an empty list if it has none."""
        return getattr(self._base_app, "routes", [])

    def matches(self, scope: Scope) -> tuple[Match, Scope]:
        """
        Match "http" and "websocket" scopes whose path lies below the mount prefix.

        On a match the child scope carries the merged path parameters, the
        mounted app as "endpoint", and a "root_path" extended by the matched prefix.
        """
        path_params: dict[str, Any]
        if scope["type"] in ("http", "websocket"):  # pragma: no branch
            root_path = scope.get("root_path", "")
            route_path = get_route_path(scope)
            match = self.path_regex.match(route_path)
            if match:
                matched_params = match.groupdict()
                for key, value in matched_params.items():
                    matched_params[key] = self.param_convertors[key].convert(value)
                # The trailing "path" parameter is the part left for the mounted app.
                remaining_path = "/" + matched_params.pop("path")
                matched_path = route_path[: -len(remaining_path)]
                path_params = dict(scope.get("path_params", {}))
                path_params.update(matched_params)
                child_scope = {
                    "path_params": path_params,
                    # app_root_path will only be set at the top level scope,
                    # initialized with the (optional) value of a root_path
                    # set above/before Starlette. And even though any
                    # mount will have its own child scope with its own respective
                    # root_path, the app_root_path will always be available in all
                    # the child scopes with the same top level value because it's
                    # set only once here with a default, any other child scope will
                    # just inherit that app_root_path default value stored in the
                    # scope. All this is needed to support Request.url_for(), as it
                    # uses the app_root_path to build the URL path.
                    "app_root_path": scope.get("app_root_path", root_path),
                    "root_path": root_path + matched_path,
                    "endpoint": self.app,
                }
                return Match.FULL, child_scope
        return Match.NONE, {}

    def url_path_for(self, name: str, /, **path_params: Any) -> URLPath:
        """
        Build a URL path for the mount itself (`name == "<mount_name>"`, with a
        `path` parameter) or for one of its child routes
        (`"<mount_name>:<child_name>"`, or just `"<child_name>"` for an unnamed mount).

        Raises `NoMatchFound` when neither applies.
        """
        # `replace_params()` removes the entries it consumes from the dict it is
        # given, and "path" is overwritten below. Keep an untouched copy so the
        # error raised at the end reports the parameters that were actually requested.
        requested_params = dict(path_params)
        if self.name is not None and name == self.name and "path" in path_params:
            # 'name' matches "<mount_name>".
            path_params["path"] = path_params["path"].lstrip("/")
            path, remaining_params = replace_params(self.path_format, self.param_convertors, path_params)
            if not remaining_params:
                return URLPath(path=path)
        elif self.name is None or name.startswith(self.name + ":"):
            if self.name is None:
                # No mount name.
                remaining_name = name
            else:
                # 'name' matches "<mount_name>:<child_name>".
                remaining_name = name[len(self.name) + 1 :]
            # Render the prefix with an empty trailing path, then give a
            # caller-supplied "path" back to the child route lookup.
            path_kwarg = path_params.get("path")
            path_params["path"] = ""
            path_prefix, remaining_params = replace_params(self.path_format, self.param_convertors, path_params)
            if path_kwarg is not None:
                remaining_params["path"] = path_kwarg
            for route in self.routes or []:
                try:
                    url = route.url_path_for(remaining_name, **remaining_params)
                    return URLPath(path=path_prefix.rstrip("/") + str(url), protocol=url.protocol)
                except NoMatchFound:
                    pass
        raise NoMatchFound(name, requested_params)

    async def handle(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Hand the connection over to the mounted application."""
        await self.app(scope, receive, send)

    def __eq__(self, other: Any) -> bool:
        return isinstance(other, Mount) and self.path == other.path and self.app == other.app

    def __repr__(self) -> str:
        class_name = self.__class__.__name__
        name = self.name or ""
        return f"{class_name}(path={self.path!r}, name={name!r}, app={self.app!r})"
471
472
class Host(BaseRoute):
    """
    Route that delegates to an ASGI application based on the request's "host" header.
    """

    def __init__(self, host: str, app: ASGIApp, name: str | None = None) -> None:
        """
        `host` is a host pattern, which may contain `{param}` placeholders and
        must not start with "/". `app` is the application handling matching requests.
        """
        assert not host.startswith("/"), "Host must not start with '/'"
        self.host = host
        self.app = app
        self.name = name
        self.host_regex, self.host_format, self.param_convertors = compile_path(host)

    @property
    def routes(self) -> list[BaseRoute]:
        """Routes of the wrapped app, or an empty list if it has none."""
        return getattr(self.app, "routes", [])

    def matches(self, scope: Scope) -> tuple[Match, Scope]:
        """
        Match "http" and "websocket" scopes by their "host" header.

        Only the part of the header before the first ":" is compared, so a
        port suffix is ignored.
        """
        if scope["type"] in ("http", "websocket"):  # pragma:no branch
            headers = Headers(scope=scope)
            host = headers.get("host", "").split(":")[0]
            match = self.host_regex.match(host)
            if match:
                matched_params = match.groupdict()
                for key, value in matched_params.items():
                    matched_params[key] = self.param_convertors[key].convert(value)
                path_params = dict(scope.get("path_params", {}))
                path_params.update(matched_params)
                child_scope = {"path_params": path_params, "endpoint": self.app}
                return Match.FULL, child_scope
        return Match.NONE, {}

    def url_path_for(self, name: str, /, **path_params: Any) -> URLPath:
        """
        Build a URL path (including host) for the host route itself
        (`name == "<host_name>"`, with a `path` parameter) or for one of its
        child routes (`"<host_name>:<child_name>"`, or just `"<child_name>"`
        when the host route is unnamed).

        Raises `NoMatchFound` when neither applies.
        """
        # "path" is popped and `replace_params()` removes the entries it consumes
        # from the dict it is given. Keep an untouched copy so the error raised at
        # the end reports the parameters that were actually requested.
        requested_params = dict(path_params)
        if self.name is not None and name == self.name and "path" in path_params:
            # 'name' matches "<mount_name>".
            path = path_params.pop("path")
            host, remaining_params = replace_params(self.host_format, self.param_convertors, path_params)
            if not remaining_params:
                return URLPath(path=path, host=host)
        elif self.name is None or name.startswith(self.name + ":"):
            if self.name is None:
                # No mount name.
                remaining_name = name
            else:
                # 'name' matches "<mount_name>:<child_name>".
                remaining_name = name[len(self.name) + 1 :]
            host, remaining_params = replace_params(self.host_format, self.param_convertors, path_params)
            for route in self.routes or []:
                try:
                    url = route.url_path_for(remaining_name, **remaining_params)
                    return URLPath(path=str(url), protocol=url.protocol, host=host)
                except NoMatchFound:
                    pass
        raise NoMatchFound(name, requested_params)

    async def handle(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Hand the connection over to the wrapped application."""
        await self.app(scope, receive, send)

    def __eq__(self, other: Any) -> bool:
        return isinstance(other, Host) and self.host == other.host and self.app == other.app

    def __repr__(self) -> str:
        class_name = self.__class__.__name__
        name = self.name or ""
        return f"{class_name}(host={self.host!r}, name={name!r}, app={self.app!r})"
533
534
# Generic type variable: the value type of `_AsyncLiftContextManager`, and the
# self type returned by `_DefaultLifespan.__call__`.
_T = TypeVar("_T")
536
537
class _AsyncLiftContextManager(AbstractAsyncContextManager[_T]):
    """Adapter exposing a synchronous context manager through the async context manager protocol."""

    def __init__(self, cm: AbstractContextManager[_T]):
        # The wrapped synchronous context manager.
        self._cm = cm

    async def __aenter__(self) -> _T:
        # Delegates directly to the wrapped manager's `__enter__`; the call itself is not awaited.
        return self._cm.__enter__()

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        traceback: types.TracebackType | None,
    ) -> bool | None:
        # Passes the exception details to the wrapped manager's `__exit__` and
        # returns its result unchanged.
        return self._cm.__exit__(exc_type, exc_value, traceback)
552
553
def _wrap_gen_lifespan_context(
    lifespan_context: Callable[[Any], Generator[Any, Any, Any]],
) -> Callable[[Any], AbstractAsyncContextManager[Any]]:
    """
    Adapt a generator-function lifespan so that calling it with the app
    returns an async context manager.

    The generator function is first turned into a synchronous context manager
    factory, whose result is then wrapped in `_AsyncLiftContextManager`.
    """
    sync_cm_factory = contextlib.contextmanager(lifespan_context)

    @functools.wraps(sync_cm_factory)
    def wrapper(app: Any) -> _AsyncLiftContextManager[Any]:
        sync_cm = sync_cm_factory(app)
        return _AsyncLiftContextManager(sync_cm)

    return wrapper
564
565
class _DefaultLifespan:
    """
    Lifespan used by `Router` when no `lifespan` argument is given: it runs the
    router's `startup()` on entry and `shutdown()` on exit.
    """

    def __init__(self, router: Router):
        self._router = router

    async def __aenter__(self) -> None:
        await self._router.startup()

    async def __aexit__(self, *exc_info: object) -> None:
        # Exception details are ignored; shutdown runs either way.
        await self._router.shutdown()

    def __call__(self: _T, app: object) -> _T:
        # Calling the instance with the app returns the instance itself, so it can
        # be used as `lifespan_context(app)` in an `async with` statement.
        return self
578
579
class Router:
    """
    ASGI application that dispatches "http" and "websocket" scopes to its
    routes and serves "lifespan" scopes through its lifespan context.
    """

    def __init__(
        self,
        routes: Sequence[BaseRoute] | None = None,
        redirect_slashes: bool = True,
        default: ASGIApp | None = None,
        on_startup: Sequence[Callable[[], Any]] | None = None,
        on_shutdown: Sequence[Callable[[], Any]] | None = None,
        # the generic to Lifespan[AppType] is the type of the top level application
        # which the router cannot know statically, so we use Any
        lifespan: Lifespan[Any] | None = None,
        *,
        middleware: Sequence[Middleware] | None = None,
    ) -> None:
        """
        `routes` are tried in order. `redirect_slashes` enables redirects between
        paths with and without a trailing slash. `default` is the app called when
        nothing matches (`not_found` if omitted). `on_startup` / `on_shutdown` are
        deprecated handler lists. `lifespan` may be an async context manager
        factory or a (deprecated) generator / async generator function.
        `middleware`, if given, is wrapped around the router's own `app`.
        """
        self.routes = [] if routes is None else list(routes)
        self.redirect_slashes = redirect_slashes
        self.default = self.not_found if default is None else default
        self.on_startup = [] if on_startup is None else list(on_startup)
        self.on_shutdown = [] if on_shutdown is None else list(on_shutdown)

        if on_startup or on_shutdown:
            warnings.warn(
                "The on_startup and on_shutdown parameters are deprecated, and they "
                "will be removed on version 1.0. Use the lifespan parameter instead. "
                "See more about it on https://www.starlette.io/lifespan/.",
                DeprecationWarning,
            )
            if lifespan:
                warnings.warn(
                    "The `lifespan` parameter cannot be used with `on_startup` or "
                    "`on_shutdown`. Both `on_startup` and `on_shutdown` will be "
                    "ignored."
                )

        if lifespan is None:
            # Without an explicit lifespan, run the on_startup / on_shutdown handlers.
            self.lifespan_context: Lifespan[Any] = _DefaultLifespan(self)

        elif inspect.isasyncgenfunction(lifespan):
            warnings.warn(
                "async generator function lifespans are deprecated, "
                "use an @contextlib.asynccontextmanager function instead",
                DeprecationWarning,
            )
            self.lifespan_context = asynccontextmanager(
                lifespan,
            )
        elif inspect.isgeneratorfunction(lifespan):
            warnings.warn(
                "generator function lifespans are deprecated, use an @contextlib.asynccontextmanager function instead",
                DeprecationWarning,
            )
            self.lifespan_context = _wrap_gen_lifespan_context(
                lifespan,
            )
        else:
            self.lifespan_context = lifespan

        # Wrap back to front, so the first middleware listed ends up outermost.
        self.middleware_stack = self.app
        if middleware:
            for cls, args, kwargs in reversed(middleware):
                self.middleware_stack = cls(self.middleware_stack, *args, **kwargs)

    async def not_found(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Default handler for unmatched requests: close websockets, answer 404 over HTTP."""
        if scope["type"] == "websocket":
            websocket_close = WebSocketClose()
            await websocket_close(scope, receive, send)
            return

        # If we're running inside a starlette application then raise an
        # exception, so that the configurable exception handler can deal with
        # returning the response. For plain ASGI apps, just return the response.
        if "app" in scope:
            raise HTTPException(status_code=404)
        else:
            response = PlainTextResponse("Not Found", status_code=404)
            await response(scope, receive, send)

    def url_path_for(self, name: str, /, **path_params: Any) -> URLPath:
        """
        Return the URL path from the first route able to build it.

        Raises `NoMatchFound` when no route can.
        """
        for route in self.routes:
            try:
                return route.url_path_for(name, **path_params)
            except NoMatchFound:
                pass
        raise NoMatchFound(name, path_params)

    async def startup(self) -> None:
        """
        Run any `.on_startup` event handlers.
        """
        for handler in self.on_startup:
            if is_async_callable(handler):
                await handler()
            else:
                handler()

    async def shutdown(self) -> None:
        """
        Run any `.on_shutdown` event handlers.
        """
        for handler in self.on_shutdown:
            if is_async_callable(handler):
                await handler()
            else:
                handler()

    async def lifespan(self, scope: Scope, receive: Receive, send: Send) -> None:
        """
        Handle ASGI lifespan messages, which allows us to manage application
        startup and shutdown events.
        """
        # Tracks whether startup finished, to pick the right failure message below.
        started = False
        app: Any = scope.get("app")
        await receive()
        try:
            async with self.lifespan_context(app) as maybe_state:
                if maybe_state is not None:
                    if "state" not in scope:
                        raise RuntimeError('The server does not support "state" in the lifespan scope.')
                    scope["state"].update(maybe_state)
                await send({"type": "lifespan.startup.complete"})
                started = True
                # Second message received; leaving the context afterwards runs the shutdown side.
                await receive()
        except BaseException:
            exc_text = traceback.format_exc()
            if started:
                await send({"type": "lifespan.shutdown.failed", "message": exc_text})
            else:
                await send({"type": "lifespan.startup.failed", "message": exc_text})
            raise
        else:
            await send({"type": "lifespan.shutdown.complete"})

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """
        The main entry point to the Router class.
        """
        await self.middleware_stack(scope, receive, send)

    async def app(self, scope: Scope, receive: Receive, send: Send) -> None:
        """
        Dispatch a scope: lifespan handling, then the first fully matching
        route, then the first partial match, then a slash redirect, and
        finally the default handler.
        """
        assert scope["type"] in ("http", "websocket", "lifespan")

        if "router" not in scope:
            scope["router"] = self

        if scope["type"] == "lifespan":
            await self.lifespan(scope, receive, send)
            return

        partial = None

        for route in self.routes:
            # Determine if any route matches the incoming scope,
            # and hand over to the matching route if found.
            match, child_scope = route.matches(scope)
            if match == Match.FULL:
                scope.update(child_scope)
                await route.handle(scope, receive, send)
                return
            elif match == Match.PARTIAL and partial is None:
                # Only the first partial match is remembered.
                partial = route
                partial_scope = child_scope

        if partial is not None:
            # Handle partial matches. These are cases where an endpoint is
            # able to handle the request, but is not a preferred option.
            # We use this in particular to deal with "405 Method Not Allowed".
            scope.update(partial_scope)
            await partial.handle(scope, receive, send)
            return

        route_path = get_route_path(scope)
        if scope["type"] == "http" and self.redirect_slashes and route_path != "/":
            # Retry the lookup with the trailing slash toggled, and redirect
            # there if any route would match.
            redirect_scope = dict(scope)
            if route_path.endswith("/"):
                redirect_scope["path"] = redirect_scope["path"].rstrip("/")
            else:
                redirect_scope["path"] = redirect_scope["path"] + "/"

            for route in self.routes:
                match, child_scope = route.matches(redirect_scope)
                if match != Match.NONE:
                    redirect_url = URL(scope=redirect_scope)
                    response = RedirectResponse(url=str(redirect_url))
                    await response(scope, receive, send)
                    return

        await self.default(scope, receive, send)

    def __eq__(self, other: Any) -> bool:
        return isinstance(other, Router) and self.routes == other.routes

    def mount(self, path: str, app: ASGIApp, name: str | None = None) -> None:  # pragma: no cover
        """Append a `Mount` for `app` at `path`."""
        route = Mount(path, app=app, name=name)
        self.routes.append(route)

    def host(self, host: str, app: ASGIApp, name: str | None = None) -> None:  # pragma: no cover
        """Append a `Host` route for `app`."""
        route = Host(host, app=app, name=name)
        self.routes.append(route)

    def add_route(
        self,
        path: str,
        endpoint: Callable[[Request], Awaitable[Response] | Response],
        methods: Collection[str] | None = None,
        name: str | None = None,
        include_in_schema: bool = True,
    ) -> None:  # pragma: no cover
        """Append a `Route` for `endpoint` at `path`."""
        route = Route(
            path,
            endpoint=endpoint,
            methods=methods,
            name=name,
            include_in_schema=include_in_schema,
        )
        self.routes.append(route)

    def add_websocket_route(
        self,
        path: str,
        endpoint: Callable[[WebSocket], Awaitable[None]],
        name: str | None = None,
    ) -> None:  # pragma: no cover
        """Append a `WebSocketRoute` for `endpoint` at `path`."""
        route = WebSocketRoute(path, endpoint=endpoint, name=name)
        self.routes.append(route)

    def route(
        self,
        path: str,
        methods: Collection[str] | None = None,
        name: str | None = None,
        include_in_schema: bool = True,
    ) -> Callable:  # type: ignore[type-arg]
        """
        We no longer document this decorator style API, and its usage is discouraged.
        Instead you should use the following approach:

        >>> routes = [Route(path, endpoint=...), ...]
        >>> app = Starlette(routes=routes)
        """
        warnings.warn(
            "The `route` decorator is deprecated, and will be removed in version 1.0.0. "
            "Refer to https://www.starlette.io/routing/#http-routing for the recommended approach.",
            DeprecationWarning,
            # Attribute the warning to the code calling this method.
            stacklevel=2,
        )

        def decorator(func: Callable) -> Callable:  # type: ignore[type-arg]
            self.add_route(
                path,
                func,
                methods=methods,
                name=name,
                include_in_schema=include_in_schema,
            )
            return func

        return decorator

    def websocket_route(self, path: str, name: str | None = None) -> Callable:  # type: ignore[type-arg]
        """
        We no longer document this decorator style API, and its usage is discouraged.
        Instead you should use the following approach:

        >>> routes = [WebSocketRoute(path, endpoint=...), ...]
        >>> app = Starlette(routes=routes)
        """
        warnings.warn(
            "The `websocket_route` decorator is deprecated, and will be removed in version 1.0.0. Refer to "
            "https://www.starlette.io/routing/#websocket-routing for the recommended approach.",
            DeprecationWarning,
            # Attribute the warning to the code calling this method.
            stacklevel=2,
        )

        def decorator(func: Callable) -> Callable:  # type: ignore[type-arg]
            self.add_websocket_route(path, func, name=name)
            return func

        return decorator

    def add_event_handler(self, event_type: str, func: Callable[[], Any]) -> None:  # pragma: no cover
        """Register `func` as a "startup" or "shutdown" handler."""
        assert event_type in ("startup", "shutdown")

        if event_type == "startup":
            self.on_startup.append(func)
        else:
            self.on_shutdown.append(func)

    def on_event(self, event_type: str) -> Callable:  # type: ignore[type-arg]
        """Deprecated decorator form of `add_event_handler`."""
        warnings.warn(
            "The `on_event` decorator is deprecated, and will be removed in version 1.0.0. "
            "Refer to https://www.starlette.io/lifespan/ for recommended approach.",
            DeprecationWarning,
            # Attribute the warning to the code calling this method.
            stacklevel=2,
        )

        def decorator(func: Callable) -> Callable:  # type: ignore[type-arg]
            self.add_event_handler(event_type, func)
            return func

        return decorator