1from __future__ import annotations
2
3import contextlib
4import functools
5import inspect
6import re
7import traceback
8import types
9import typing
10import warnings
11from contextlib import asynccontextmanager
12from enum import Enum
13
14from starlette._exception_handler import wrap_app_handling_exceptions
15from starlette._utils import get_route_path, is_async_callable
16from starlette.concurrency import run_in_threadpool
17from starlette.convertors import CONVERTOR_TYPES, Convertor
18from starlette.datastructures import URL, Headers, URLPath
19from starlette.exceptions import HTTPException
20from starlette.middleware import Middleware
21from starlette.requests import Request
22from starlette.responses import PlainTextResponse, RedirectResponse, Response
23from starlette.types import ASGIApp, Lifespan, Receive, Scope, Send
24from starlette.websockets import WebSocket, WebSocketClose
25
26
class NoMatchFound(Exception):
    """
    Raised by `.url_for(name, **path_params)` and `.url_path_for(name, **path_params)`
    if no matching route exists.
    """

    def __init__(self, name: str, path_params: dict[str, typing.Any]) -> None:
        # Only the parameter names are reported; their values are not included.
        joined_keys = ", ".join(key for key in path_params)
        message = f'No route exists for name "{name}" and params "{joined_keys}".'
        super().__init__(message)
36
37
class Match(Enum):
    """Outcome of testing a route against an incoming scope."""

    # The route does not apply to this scope.
    NONE = 0
    # The path matched, but something else (e.g. the HTTP method) did not.
    PARTIAL = 1
    # The route can handle this scope.
    FULL = 2
42
43
def iscoroutinefunction_or_partial(obj: typing.Any) -> bool:  # pragma: no cover
    """
    Correctly determines if an object is a coroutine function,
    including those wrapped in functools.partial objects.

    Deprecated: every call emits a ``DeprecationWarning`` attributed to the
    caller's line.
    """
    warnings.warn(
        "iscoroutinefunction_or_partial is deprecated, and will be removed in a future release.",
        DeprecationWarning,
        # Point the warning at the code that called this helper, not at the helper itself.
        stacklevel=2,
    )
    # Unwrap any (possibly nested) functools.partial layers before inspecting.
    while isinstance(obj, functools.partial):
        obj = obj.func
    return inspect.iscoroutinefunction(obj)
56
57
def request_response(
    func: typing.Callable[[Request], typing.Awaitable[Response] | Response],
) -> ASGIApp:
    """
    Adapt a request handler into an ASGI application.

    ``func`` is called with a ``Request`` and must produce a ``Response``. It
    may be a coroutine function, or a plain callable; a callable that is not
    async is executed through ``run_in_threadpool``.
    """
    handler: typing.Callable[[Request], typing.Awaitable[Response]]
    if is_async_callable(func):
        handler = func  # type:ignore
    else:
        handler = functools.partial(run_in_threadpool, func)  # type:ignore

    async def app(scope: Scope, receive: Receive, send: Send) -> None:
        request = Request(scope, receive, send)

        async def call_endpoint(scope: Scope, receive: Receive, send: Send) -> None:
            response = await handler(request)
            await response(scope, receive, send)

        # The endpoint call is wrapped by wrap_app_handling_exceptions.
        await wrap_app_handling_exceptions(call_endpoint, request)(scope, receive, send)

    return app
79
80
def websocket_session(
    func: typing.Callable[[WebSocket], typing.Awaitable[None]],
) -> ASGIApp:
    """
    Adapt a coroutine ``func(websocket)`` into an ASGI application.

    No check is performed here that ``func`` is actually a coroutine function.
    """

    async def app(scope: Scope, receive: Receive, send: Send) -> None:
        websocket = WebSocket(scope, receive=receive, send=send)

        async def run_session(scope: Scope, receive: Receive, send: Send) -> None:
            await func(websocket)

        # The session call is wrapped by wrap_app_handling_exceptions.
        await wrap_app_handling_exceptions(run_session, websocket)(scope, receive, send)

    return app
98
99
def get_name(endpoint: typing.Callable[..., typing.Any]) -> str:
    """Return ``endpoint.__name__``, or the name of its class when it has none."""
    fallback = endpoint.__class__.__name__
    try:
        return endpoint.__name__
    except AttributeError:
        return fallback
102
103
# Placeholders as they appear in a compiled path format, e.g. "{username}".
_FORMAT_PLACEHOLDER_REGEX = re.compile(r"\{([a-zA-Z_][a-zA-Z0-9_]*)\}")


def replace_params(
    path: str,
    param_convertors: dict[str, Convertor[typing.Any]],
    path_params: dict[str, typing.Any],
) -> tuple[str, dict[str, typing.Any]]:
    """
    Substitute ``{name}`` placeholders in ``path`` with converted parameter values.

    Each placeholder whose name is a key of ``path_params`` is replaced with
    ``param_convertors[name].to_string(value)``. Placeholders with no matching
    key are left untouched. Consumed keys are removed from ``path_params``
    (in place), and that same dict is returned alongside the new path.

    Substitution happens in a single left-to-right pass over ``path``, so a
    converted value that itself contains text such as ``"{other}"`` is never
    re-interpreted as a placeholder.

    Raises:
        KeyError: if a placeholder is filled from ``path_params`` but has no
            entry in ``param_convertors``.
    """
    consumed: set[str] = set()

    def _substitute(match: re.Match[str]) -> str:
        key = match.group(1)
        if key not in path_params:
            # Not supplied: keep the placeholder text as-is.
            return match.group(0)
        consumed.add(key)
        return param_convertors[key].to_string(path_params[key])

    path = _FORMAT_PLACEHOLDER_REGEX.sub(_substitute, path)
    for key in consumed:
        path_params.pop(key)
    return path, path_params
116
117
# Match parameters in URL paths, e.g. '{param}' and '{param:int}'.
# Group 1 is the parameter name; group 2 is the optional ':convertor' suffix.
PARAM_REGEX = re.compile(r"{([a-zA-Z_][a-zA-Z0-9_]*)(:[a-zA-Z_][a-zA-Z0-9_]*)?}")
120
121
def compile_path(
    path: str,
) -> tuple[typing.Pattern[str], str, dict[str, Convertor[typing.Any]]]:
    """
    Given a path string, like: "/{username:str}",
    or a host string, like: "{subdomain}.mydomain.org", return a three-tuple
    of (regex, format, {param_name:convertor}).

    regex: "/(?P<username>[^/]+)"
    format: "/{username}"
    convertors: {"username": StringConvertor()}

    Parameters without an explicit convertor default to "str". A repeated
    parameter name raises ValueError; an unknown convertor fails an assert.
    """
    # Anything not starting with "/" is treated as a host pattern.
    is_host = not path.startswith("/")

    regex_parts = ["^"]
    format_parts: list[str] = []
    param_convertors: dict[str, Convertor[typing.Any]] = {}
    duplicated_params: set[str] = set()

    cursor = 0
    for param_match in PARAM_REGEX.finditer(path):
        param_name, convertor_type = param_match.groups("str")
        convertor_type = convertor_type.lstrip(":")
        assert convertor_type in CONVERTOR_TYPES, f"Unknown path convertor '{convertor_type}'"
        convertor = CONVERTOR_TYPES[convertor_type]

        # Literal text between the previous parameter and this one.
        literal = path[cursor : param_match.start()]
        regex_parts.append(re.escape(literal))
        regex_parts.append(f"(?P<{param_name}>{convertor.regex})")
        format_parts.append(literal)
        format_parts.append("{" + param_name + "}")

        if param_name in param_convertors:
            duplicated_params.add(param_name)
        param_convertors[param_name] = convertor

        cursor = param_match.end()

    if duplicated_params:
        names = ", ".join(sorted(duplicated_params))
        ending = "s" if len(duplicated_params) > 1 else ""
        raise ValueError(f"Duplicated param name{ending} {names} at path {path}")

    tail = path[cursor:]
    if is_host:
        # Align with `Host.matches()` behavior, which ignores port.
        regex_tail = tail.split(":")[0]
    else:
        regex_tail = tail
    regex_parts.append(re.escape(regex_tail) + "$")
    # The format keeps the full tail (including any port) unchanged.
    format_parts.append(tail)

    return re.compile("".join(regex_parts)), "".join(format_parts), param_convertors
176
177
class BaseRoute:
    """
    Interface shared by routing entries.

    Subclasses implement ``matches``, ``url_path_for`` and ``handle``; the
    base class only provides ``__call__`` so a route can run on its own.
    """

    def matches(self, scope: Scope) -> tuple[Match, Scope]:
        raise NotImplementedError()  # pragma: no cover

    def url_path_for(self, name: str, /, **path_params: typing.Any) -> URLPath:
        raise NotImplementedError()  # pragma: no cover

    async def handle(self, scope: Scope, receive: Receive, send: Send) -> None:
        raise NotImplementedError()  # pragma: no cover

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """
        A route may be used in isolation as a stand-alone ASGI app.
        This is a somewhat contrived case, as they'll almost always be used
        within a Router, but could be useful for some tooling and minimal apps.
        """
        match, child_scope = self.matches(scope)
        if match != Match.NONE:
            scope.update(child_scope)
            await self.handle(scope, receive, send)
            return

        # No match: answer with a 404 for HTTP or close the WebSocket;
        # any other scope type is ignored.
        scope_type = scope["type"]
        if scope_type == "http":
            not_found = PlainTextResponse("Not Found", status_code=404)
            await not_found(scope, receive, send)
        elif scope_type == "websocket":
            closer = WebSocketClose()
            await closer(scope, receive, send)
206
207
class Route(BaseRoute):
    """
    An HTTP route that dispatches requests whose path matches ``path`` to ``endpoint``.

    Functions and methods (optionally wrapped in ``functools.partial``) are
    adapted with ``request_response`` and default to ``methods=["GET"]``. Any
    other callable is used directly as an ASGI application. Whenever ``GET``
    is allowed, ``HEAD`` is added automatically.
    """

    def __init__(
        self,
        path: str,
        endpoint: typing.Callable[..., typing.Any],
        *,
        methods: list[str] | None = None,
        name: str | None = None,
        include_in_schema: bool = True,
        middleware: typing.Sequence[Middleware] | None = None,
    ) -> None:
        assert path.startswith("/"), "Routed paths must start with '/'"
        self.path = path
        self.endpoint = endpoint
        self.name = get_name(endpoint) if name is None else name
        self.include_in_schema = include_in_schema

        # Look through any functools.partial layers to classify the real callable.
        endpoint_handler = endpoint
        while isinstance(endpoint_handler, functools.partial):
            endpoint_handler = endpoint_handler.func
        if inspect.isfunction(endpoint_handler) or inspect.ismethod(endpoint_handler):
            # Endpoint is function or method. Treat it as `func(request) -> response`.
            self.app = request_response(endpoint)
            if methods is None:
                methods = ["GET"]
        else:
            # Endpoint is a class. Treat it as ASGI.
            self.app = endpoint

        if middleware is not None:
            # Wrap in reverse so the first listed middleware ends up outermost.
            # The wrapped app is passed positionally (as Router does) so that
            # positional middleware args cannot collide with an `app=` keyword.
            for cls, args, kwargs in reversed(middleware):
                self.app = cls(self.app, *args, **kwargs)

        if methods is None:
            self.methods = None
        else:
            self.methods = {method.upper() for method in methods}
            if "GET" in self.methods:
                self.methods.add("HEAD")

        self.path_regex, self.path_format, self.param_convertors = compile_path(path)

    def matches(self, scope: Scope) -> tuple[Match, Scope]:
        """
        Return FULL for a path + method match, PARTIAL when only the path
        matches (method not allowed), and NONE otherwise.
        """
        path_params: dict[str, typing.Any]
        if scope["type"] == "http":
            route_path = get_route_path(scope)
            match = self.path_regex.match(route_path)
            if match:
                matched_params = match.groupdict()
                for key, value in matched_params.items():
                    matched_params[key] = self.param_convertors[key].convert(value)
                path_params = dict(scope.get("path_params", {}))
                path_params.update(matched_params)
                child_scope = {"endpoint": self.endpoint, "path_params": path_params}
                if self.methods and scope["method"] not in self.methods:
                    return Match.PARTIAL, child_scope
                else:
                    return Match.FULL, child_scope
        return Match.NONE, {}

    def url_path_for(self, name: str, /, **path_params: typing.Any) -> URLPath:
        """
        Build this route's URL path; ``name`` and the set of parameter names
        must match exactly, otherwise ``NoMatchFound`` is raised.
        """
        seen_params = set(path_params.keys())
        expected_params = set(self.param_convertors.keys())

        if name != self.name or seen_params != expected_params:
            raise NoMatchFound(name, path_params)

        path, remaining_params = replace_params(self.path_format, self.param_convertors, path_params)
        assert not remaining_params
        return URLPath(path=path, protocol="http")

    async def handle(self, scope: Scope, receive: Receive, send: Send) -> None:
        """
        Call the wrapped app, or respond 405 when the method is not allowed.
        """
        if self.methods and scope["method"] not in self.methods:
            # Sorted so the Allow header is deterministic; set iteration order is not.
            headers = {"Allow": ", ".join(sorted(self.methods))}
            if "app" in scope:
                raise HTTPException(status_code=405, headers=headers)
            else:
                response = PlainTextResponse("Method Not Allowed", status_code=405, headers=headers)
                await response(scope, receive, send)
        else:
            await self.app(scope, receive, send)

    def __eq__(self, other: typing.Any) -> bool:
        return (
            isinstance(other, Route)
            and self.path == other.path
            and self.endpoint == other.endpoint
            and self.methods == other.methods
        )

    def __repr__(self) -> str:
        class_name = self.__class__.__name__
        methods = sorted(self.methods or [])
        path, name = self.path, self.name
        return f"{class_name}(path={path!r}, name={name!r}, methods={methods!r})"
303
304
class WebSocketRoute(BaseRoute):
    """
    A WebSocket route that dispatches sessions whose path matches ``path`` to ``endpoint``.

    Functions and methods (optionally wrapped in ``functools.partial``) are
    adapted with ``websocket_session``. Any other callable is used directly
    as an ASGI application.
    """

    def __init__(
        self,
        path: str,
        endpoint: typing.Callable[..., typing.Any],
        *,
        name: str | None = None,
        middleware: typing.Sequence[Middleware] | None = None,
    ) -> None:
        assert path.startswith("/"), "Routed paths must start with '/'"
        self.path = path
        self.endpoint = endpoint
        self.name = get_name(endpoint) if name is None else name

        # Look through any functools.partial layers to classify the real callable.
        endpoint_handler = endpoint
        while isinstance(endpoint_handler, functools.partial):
            endpoint_handler = endpoint_handler.func
        if inspect.isfunction(endpoint_handler) or inspect.ismethod(endpoint_handler):
            # Endpoint is function or method. Treat it as `func(websocket)`.
            self.app = websocket_session(endpoint)
        else:
            # Endpoint is a class. Treat it as ASGI.
            self.app = endpoint

        if middleware is not None:
            # Wrap in reverse so the first listed middleware ends up outermost.
            # The wrapped app is passed positionally (as Router does) so that
            # positional middleware args cannot collide with an `app=` keyword.
            for cls, args, kwargs in reversed(middleware):
                self.app = cls(self.app, *args, **kwargs)

        self.path_regex, self.path_format, self.param_convertors = compile_path(path)

    def matches(self, scope: Scope) -> tuple[Match, Scope]:
        """Return FULL for a websocket scope whose path matches, NONE otherwise."""
        path_params: dict[str, typing.Any]
        if scope["type"] == "websocket":
            route_path = get_route_path(scope)
            match = self.path_regex.match(route_path)
            if match:
                matched_params = match.groupdict()
                for key, value in matched_params.items():
                    matched_params[key] = self.param_convertors[key].convert(value)
                path_params = dict(scope.get("path_params", {}))
                path_params.update(matched_params)
                child_scope = {"endpoint": self.endpoint, "path_params": path_params}
                return Match.FULL, child_scope
        return Match.NONE, {}

    def url_path_for(self, name: str, /, **path_params: typing.Any) -> URLPath:
        """
        Build this route's URL path; ``name`` and the set of parameter names
        must match exactly, otherwise ``NoMatchFound`` is raised.
        """
        seen_params = set(path_params.keys())
        expected_params = set(self.param_convertors.keys())

        if name != self.name or seen_params != expected_params:
            raise NoMatchFound(name, path_params)

        path, remaining_params = replace_params(self.path_format, self.param_convertors, path_params)
        assert not remaining_params
        return URLPath(path=path, protocol="websocket")

    async def handle(self, scope: Scope, receive: Receive, send: Send) -> None:
        await self.app(scope, receive, send)

    def __eq__(self, other: typing.Any) -> bool:
        return isinstance(other, WebSocketRoute) and self.path == other.path and self.endpoint == other.endpoint

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(path={self.path!r}, name={self.name!r})"
369
370
class Mount(BaseRoute):
    """
    Route every path under a prefix to a sub-application.

    Either ``app`` (any ASGI app) or ``routes`` (wrapped in a new ``Router``)
    must be given. On a match, the child scope's ``root_path`` is extended
    with the matched prefix.
    """

    def __init__(
        self,
        path: str,
        app: ASGIApp | None = None,
        routes: typing.Sequence[BaseRoute] | None = None,
        name: str | None = None,
        *,
        middleware: typing.Sequence[Middleware] | None = None,
    ) -> None:
        assert path == "" or path.startswith("/"), "Routed paths must start with '/'"
        assert app is not None or routes is not None, "Either 'app=...', or 'routes=' must be specified"
        self.path = path.rstrip("/")
        if app is not None:
            self._base_app: ASGIApp = app
        else:
            self._base_app = Router(routes=routes)
        self.app = self._base_app
        if middleware is not None:
            # Wrap in reverse so the first listed middleware ends up outermost.
            # The wrapped app is passed positionally (as Router does) so that
            # positional middleware args cannot collide with an `app=` keyword.
            for cls, args, kwargs in reversed(middleware):
                self.app = cls(self.app, *args, **kwargs)
        self.name = name
        self.path_regex, self.path_format, self.param_convertors = compile_path(self.path + "/{path:path}")

    @property
    def routes(self) -> list[BaseRoute]:
        # Routes of the un-wrapped app, if it exposes any.
        return getattr(self._base_app, "routes", [])

    def matches(self, scope: Scope) -> tuple[Match, Scope]:
        """
        Return FULL for http/websocket scopes under this prefix, NONE otherwise.
        """
        path_params: dict[str, typing.Any]
        if scope["type"] in ("http", "websocket"):
            root_path = scope.get("root_path", "")
            route_path = get_route_path(scope)
            match = self.path_regex.match(route_path)
            if match:
                matched_params = match.groupdict()
                for key, value in matched_params.items():
                    matched_params[key] = self.param_convertors[key].convert(value)
                remaining_path = "/" + matched_params.pop("path")
                matched_path = route_path[: -len(remaining_path)]
                path_params = dict(scope.get("path_params", {}))
                path_params.update(matched_params)
                child_scope = {
                    "path_params": path_params,
                    # app_root_path will only be set at the top level scope,
                    # initialized with the (optional) value of a root_path
                    # set above/before Starlette. And even though any
                    # mount will have its own child scope with its own respective
                    # root_path, the app_root_path will always be available in all
                    # the child scopes with the same top level value because it's
                    # set only once here with a default, any other child scope will
                    # just inherit that app_root_path default value stored in the
                    # scope. All this is needed to support Request.url_for(), as it
                    # uses the app_root_path to build the URL path.
                    "app_root_path": scope.get("app_root_path", root_path),
                    "root_path": root_path + matched_path,
                    "endpoint": self.app,
                }
                return Match.FULL, child_scope
        return Match.NONE, {}

    def url_path_for(self, name: str, /, **path_params: typing.Any) -> URLPath:
        """
        Resolve ``name`` either as this mount itself (``"<mount_name>"`` plus a
        ``path`` param) or as ``"<mount_name>:<child_name>"`` delegated to the
        child routes. Raises ``NoMatchFound`` when neither resolves.
        """
        if self.name is not None and name == self.name and "path" in path_params:
            # 'name' matches "<mount_name>".
            path_params["path"] = path_params["path"].lstrip("/")
            path, remaining_params = replace_params(self.path_format, self.param_convertors, path_params)
            if not remaining_params:
                return URLPath(path=path)
        elif self.name is None or name.startswith(self.name + ":"):
            if self.name is None:
                # No mount name.
                remaining_name = name
            else:
                # 'name' matches "<mount_name>:<child_name>".
                remaining_name = name[len(self.name) + 1 :]
            path_kwarg = path_params.get("path")
            # Render only the mount prefix; a caller-supplied "path" is handed to the children.
            path_params["path"] = ""
            path_prefix, remaining_params = replace_params(self.path_format, self.param_convertors, path_params)
            if path_kwarg is not None:
                remaining_params["path"] = path_kwarg
            for route in self.routes or []:
                try:
                    url = route.url_path_for(remaining_name, **remaining_params)
                    return URLPath(path=path_prefix.rstrip("/") + str(url), protocol=url.protocol)
                except NoMatchFound:
                    pass
        raise NoMatchFound(name, path_params)

    async def handle(self, scope: Scope, receive: Receive, send: Send) -> None:
        await self.app(scope, receive, send)

    def __eq__(self, other: typing.Any) -> bool:
        return isinstance(other, Mount) and self.path == other.path and self.app == other.app

    def __repr__(self) -> str:
        class_name = self.__class__.__name__
        name = self.name or ""
        return f"{class_name}(path={self.path!r}, name={name!r}, app={self.app!r})"
469
470
class Host(BaseRoute):
    """
    Route on the request's ``Host`` header instead of its path.

    ``host`` may contain parameters (e.g. ``"{subdomain}.example.org"``), which
    are converted and merged into the scope's ``path_params`` on a match.
    """

    def __init__(self, host: str, app: ASGIApp, name: str | None = None) -> None:
        assert not host.startswith("/"), "Host must not start with '/'"
        self.host = host
        self.app = app
        self.name = name
        self.host_regex, self.host_format, self.param_convertors = compile_path(host)

    @property
    def routes(self) -> list[BaseRoute]:
        return getattr(self.app, "routes", [])

    def matches(self, scope: Scope) -> tuple[Match, Scope]:
        if scope["type"] not in ("http", "websocket"):
            return Match.NONE, {}

        # Any ":port" suffix of the Host header is dropped before matching.
        hostname = Headers(scope=scope).get("host", "").split(":")[0]
        host_match = self.host_regex.match(hostname)
        if host_match is None:
            return Match.NONE, {}

        converted = {
            key: self.param_convertors[key].convert(raw_value)
            for key, raw_value in host_match.groupdict().items()
        }
        path_params = dict(scope.get("path_params", {}))
        path_params.update(converted)
        return Match.FULL, {"path_params": path_params, "endpoint": self.app}

    def url_path_for(self, name: str, /, **path_params: typing.Any) -> URLPath:
        if self.name is not None and name == self.name and "path" in path_params:
            # 'name' matches "<mount_name>".
            path = path_params.pop("path")
            host, leftover = replace_params(self.host_format, self.param_convertors, path_params)
            if not leftover:
                return URLPath(path=path, host=host)
        elif self.name is None or name.startswith(self.name + ":"):
            # With no name, the full name is delegated; otherwise strip "<mount_name>:".
            child_name = name if self.name is None else name[len(self.name) + 1 :]
            host, leftover = replace_params(self.host_format, self.param_convertors, path_params)
            for child in self.routes or []:
                try:
                    url = child.url_path_for(child_name, **leftover)
                except NoMatchFound:
                    continue
                return URLPath(path=str(url), protocol=url.protocol, host=host)
        raise NoMatchFound(name, path_params)

    async def handle(self, scope: Scope, receive: Receive, send: Send) -> None:
        await self.app(scope, receive, send)

    def __eq__(self, other: typing.Any) -> bool:
        return isinstance(other, Host) and self.host == other.host and self.app == other.app

    def __repr__(self) -> str:
        label = self.name or ""
        return f"{self.__class__.__name__}(host={self.host!r}, name={label!r}, app={self.app!r})"
531
532
533_T = typing.TypeVar("_T")
534
535
536class _AsyncLiftContextManager(typing.AsyncContextManager[_T]):
537 def __init__(self, cm: typing.ContextManager[_T]):
538 self._cm = cm
539
540 async def __aenter__(self) -> _T:
541 return self._cm.__enter__()
542
543 async def __aexit__(
544 self,
545 exc_type: type[BaseException] | None,
546 exc_value: BaseException | None,
547 traceback: types.TracebackType | None,
548 ) -> bool | None:
549 return self._cm.__exit__(exc_type, exc_value, traceback)
550
551
def _wrap_gen_lifespan_context(
    lifespan_context: typing.Callable[[typing.Any], typing.Generator[typing.Any, typing.Any, typing.Any]],
) -> typing.Callable[[typing.Any], typing.AsyncContextManager[typing.Any]]:
    """
    Convert a synchronous generator lifespan ``lifespan_context(app)`` into a
    callable that returns an async context manager for the same app.
    """
    sync_factory = contextlib.contextmanager(lifespan_context)

    @functools.wraps(sync_factory)
    def wrapper(app: typing.Any) -> _AsyncLiftContextManager[typing.Any]:
        sync_manager = sync_factory(app)
        return _AsyncLiftContextManager(sync_manager)

    return wrapper
562
563
564class _DefaultLifespan:
565 def __init__(self, router: Router):
566 self._router = router
567
568 async def __aenter__(self) -> None:
569 await self._router.startup()
570
571 async def __aexit__(self, *exc_info: object) -> None:
572 await self._router.shutdown()
573
574 def __call__(self: _T, app: object) -> _T:
575 return self
576
577
class Router:
    """
    Dispatch ASGI scopes to the first matching route.

    Handles ``lifespan`` scopes itself. ``http``/``websocket`` scopes go to the
    first FULL match, otherwise the first PARTIAL match. Failing both, an
    optional trailing-slash redirect is tried, and then ``default``.
    """

    def __init__(
        self,
        routes: typing.Sequence[BaseRoute] | None = None,
        redirect_slashes: bool = True,
        default: ASGIApp | None = None,
        on_startup: typing.Sequence[typing.Callable[[], typing.Any]] | None = None,
        on_shutdown: typing.Sequence[typing.Callable[[], typing.Any]] | None = None,
        # the generic to Lifespan[AppType] is the type of the top level application
        # which the router cannot know statically, so we use typing.Any
        lifespan: Lifespan[typing.Any] | None = None,
        *,
        middleware: typing.Sequence[Middleware] | None = None,
    ) -> None:
        self.routes = [] if routes is None else list(routes)
        self.redirect_slashes = redirect_slashes
        self.default = self.not_found if default is None else default
        self.on_startup = [] if on_startup is None else list(on_startup)
        self.on_shutdown = [] if on_shutdown is None else list(on_shutdown)

        if on_startup or on_shutdown:
            warnings.warn(
                "The on_startup and on_shutdown parameters are deprecated, and they "
                "will be removed on version 1.0. Use the lifespan parameter instead. "
                "See more about it on https://www.starlette.io/lifespan/.",
                DeprecationWarning,
            )
            if lifespan:
                warnings.warn(
                    "The `lifespan` parameter cannot be used with `on_startup` or "
                    "`on_shutdown`. Both `on_startup` and `on_shutdown` will be "
                    "ignored."
                )

        if lifespan is None:
            self.lifespan_context: Lifespan[typing.Any] = _DefaultLifespan(self)

        elif inspect.isasyncgenfunction(lifespan):
            warnings.warn(
                "async generator function lifespans are deprecated, "
                "use an @contextlib.asynccontextmanager function instead",
                DeprecationWarning,
            )
            self.lifespan_context = asynccontextmanager(
                lifespan,
            )
        elif inspect.isgeneratorfunction(lifespan):
            warnings.warn(
                "generator function lifespans are deprecated, "
                "use an @contextlib.asynccontextmanager function instead",
                DeprecationWarning,
            )
            self.lifespan_context = _wrap_gen_lifespan_context(
                lifespan,
            )
        else:
            self.lifespan_context = lifespan

        # Wrap in reverse so the first listed middleware ends up outermost.
        self.middleware_stack = self.app
        if middleware:
            for cls, args, kwargs in reversed(middleware):
                self.middleware_stack = cls(self.middleware_stack, *args, **kwargs)

    async def not_found(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Default fallback: close websockets; 404 for HTTP."""
        if scope["type"] == "websocket":
            websocket_close = WebSocketClose()
            await websocket_close(scope, receive, send)
            return

        # If we're running inside a starlette application then raise an
        # exception, so that the configurable exception handler can deal with
        # returning the response. For plain ASGI apps, just return the response.
        if "app" in scope:
            raise HTTPException(status_code=404)
        else:
            response = PlainTextResponse("Not Found", status_code=404)
            await response(scope, receive, send)

    def url_path_for(self, name: str, /, **path_params: typing.Any) -> URLPath:
        """Return the URL path from the first route that can build ``name``."""
        for route in self.routes:
            try:
                return route.url_path_for(name, **path_params)
            except NoMatchFound:
                pass
        raise NoMatchFound(name, path_params)

    async def startup(self) -> None:
        """
        Run any `.on_startup` event handlers.
        """
        for handler in self.on_startup:
            if is_async_callable(handler):
                await handler()
            else:
                handler()

    async def shutdown(self) -> None:
        """
        Run any `.on_shutdown` event handlers.
        """
        for handler in self.on_shutdown:
            if is_async_callable(handler):
                await handler()
            else:
                handler()

    async def lifespan(self, scope: Scope, receive: Receive, send: Send) -> None:
        """
        Handle ASGI lifespan messages, which allows us to manage application
        startup and shutdown events.
        """
        started = False
        app: typing.Any = scope.get("app")
        await receive()
        try:
            async with self.lifespan_context(app) as maybe_state:
                if maybe_state is not None:
                    if "state" not in scope:
                        raise RuntimeError('The server does not support "state" in the lifespan scope.')
                    scope["state"].update(maybe_state)
                await send({"type": "lifespan.startup.complete"})
                started = True
                await receive()
        except BaseException:
            # Report the failure to the server in the phase it happened, then re-raise.
            exc_text = traceback.format_exc()
            if started:
                await send({"type": "lifespan.shutdown.failed", "message": exc_text})
            else:
                await send({"type": "lifespan.startup.failed", "message": exc_text})
            raise
        else:
            await send({"type": "lifespan.shutdown.complete"})

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """
        The main entry point to the Router class.
        """
        await self.middleware_stack(scope, receive, send)

    async def app(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Innermost app wrapped by the router's middleware stack."""
        assert scope["type"] in ("http", "websocket", "lifespan")

        if "router" not in scope:
            scope["router"] = self

        if scope["type"] == "lifespan":
            await self.lifespan(scope, receive, send)
            return

        partial = None

        for route in self.routes:
            # Determine if any route matches the incoming scope,
            # and hand over to the matching route if found.
            match, child_scope = route.matches(scope)
            if match == Match.FULL:
                scope.update(child_scope)
                await route.handle(scope, receive, send)
                return
            elif match == Match.PARTIAL and partial is None:
                partial = route
                partial_scope = child_scope

        if partial is not None:
            # Handle partial matches. These are cases where an endpoint is
            # able to handle the request, but is not a preferred option.
            # We use this in particular to deal with "405 Method Not Allowed".
            scope.update(partial_scope)
            await partial.handle(scope, receive, send)
            return

        route_path = get_route_path(scope)
        if scope["type"] == "http" and self.redirect_slashes and route_path != "/":
            # Try the path with the trailing slash toggled; redirect if any route matches it.
            redirect_scope = dict(scope)
            if route_path.endswith("/"):
                redirect_scope["path"] = redirect_scope["path"].rstrip("/")
            else:
                redirect_scope["path"] = redirect_scope["path"] + "/"

            for route in self.routes:
                match, child_scope = route.matches(redirect_scope)
                if match != Match.NONE:
                    redirect_url = URL(scope=redirect_scope)
                    response = RedirectResponse(url=str(redirect_url))
                    await response(scope, receive, send)
                    return

        await self.default(scope, receive, send)

    def __eq__(self, other: typing.Any) -> bool:
        return isinstance(other, Router) and self.routes == other.routes

    def mount(self, path: str, app: ASGIApp, name: str | None = None) -> None:  # pragma: nocover
        route = Mount(path, app=app, name=name)
        self.routes.append(route)

    def host(self, host: str, app: ASGIApp, name: str | None = None) -> None:  # pragma: no cover
        route = Host(host, app=app, name=name)
        self.routes.append(route)

    def add_route(
        self,
        path: str,
        endpoint: typing.Callable[[Request], typing.Awaitable[Response] | Response],
        methods: list[str] | None = None,
        name: str | None = None,
        include_in_schema: bool = True,
    ) -> None:  # pragma: nocover
        route = Route(
            path,
            endpoint=endpoint,
            methods=methods,
            name=name,
            include_in_schema=include_in_schema,
        )
        self.routes.append(route)

    def add_websocket_route(
        self,
        path: str,
        endpoint: typing.Callable[[WebSocket], typing.Awaitable[None]],
        name: str | None = None,
    ) -> None:  # pragma: no cover
        route = WebSocketRoute(path, endpoint=endpoint, name=name)
        self.routes.append(route)

    def route(
        self,
        path: str,
        methods: list[str] | None = None,
        name: str | None = None,
        include_in_schema: bool = True,
    ) -> typing.Callable:  # type: ignore[type-arg]
        """
        We no longer document this decorator style API, and its usage is discouraged.
        Instead you should use the following approach:

        >>> routes = [Route(path, endpoint=...), ...]
        >>> app = Starlette(routes=routes)
        """
        warnings.warn(
            "The `route` decorator is deprecated, and will be removed in version 1.0.0. "
            "Refer to https://www.starlette.io/routing/#http-routing for the recommended approach.",
            DeprecationWarning,
            # Attribute the warning to the user's decorator call, not to this method.
            stacklevel=2,
        )

        def decorator(func: typing.Callable) -> typing.Callable:  # type: ignore[type-arg]
            self.add_route(
                path,
                func,
                methods=methods,
                name=name,
                include_in_schema=include_in_schema,
            )
            return func

        return decorator

    def websocket_route(self, path: str, name: str | None = None) -> typing.Callable:  # type: ignore[type-arg]
        """
        We no longer document this decorator style API, and its usage is discouraged.
        Instead you should use the following approach:

        >>> routes = [WebSocketRoute(path, endpoint=...), ...]
        >>> app = Starlette(routes=routes)
        """
        warnings.warn(
            "The `websocket_route` decorator is deprecated, and will be removed in version 1.0.0. Refer to "
            "https://www.starlette.io/routing/#websocket-routing for the recommended approach.",
            DeprecationWarning,
            # Attribute the warning to the user's decorator call, not to this method.
            stacklevel=2,
        )

        def decorator(func: typing.Callable) -> typing.Callable:  # type: ignore[type-arg]
            self.add_websocket_route(path, func, name=name)
            return func

        return decorator

    def add_event_handler(self, event_type: str, func: typing.Callable[[], typing.Any]) -> None:  # pragma: no cover
        assert event_type in ("startup", "shutdown")

        if event_type == "startup":
            self.on_startup.append(func)
        else:
            self.on_shutdown.append(func)

    def on_event(self, event_type: str) -> typing.Callable:  # type: ignore[type-arg]
        warnings.warn(
            "The `on_event` decorator is deprecated, and will be removed in version 1.0.0. "
            "Refer to https://www.starlette.io/lifespan/ for recommended approach.",
            DeprecationWarning,
            # Attribute the warning to the user's decorator call, not to this method.
            stacklevel=2,
        )

        def decorator(func: typing.Callable) -> typing.Callable:  # type: ignore[type-arg]
            self.add_event_handler(event_type, func)
            return func

        return decorator