1from __future__ import annotations
2
3import contextlib
4import functools
5import inspect
6import re
7import traceback
8import types
9import typing
10import warnings
11from contextlib import asynccontextmanager
12from enum import Enum
13
14from starlette._exception_handler import wrap_app_handling_exceptions
15from starlette._utils import get_route_path, is_async_callable
16from starlette.concurrency import run_in_threadpool
17from starlette.convertors import CONVERTOR_TYPES, Convertor
18from starlette.datastructures import URL, Headers, URLPath
19from starlette.exceptions import HTTPException
20from starlette.middleware import Middleware
21from starlette.requests import Request
22from starlette.responses import PlainTextResponse, RedirectResponse, Response
23from starlette.types import ASGIApp, Lifespan, Receive, Scope, Send
24from starlette.websockets import WebSocket, WebSocketClose
25
26
class NoMatchFound(Exception):
    """
    Signals a failed reverse URL lookup.

    Raised by `.url_for(name, **path_params)` and
    `.url_path_for(name, **path_params)` if no matching route exists.
    """

    def __init__(self, name: str, path_params: dict[str, typing.Any]) -> None:
        # Only the parameter names (not their values) go into the message,
        # in the order the mapping yields them.
        param_names = ", ".join(path_params)
        message = f'No route exists for name "{name}" and params "{param_names}".'
        super().__init__(message)
36
37
class Match(Enum):
    """Result of testing a route against an incoming scope."""

    # The route does not apply to the scope.
    NONE = 0
    # The route could handle the scope but is not a preferred option; `Route`
    # reports this when the path matches but the HTTP method is not allowed.
    PARTIAL = 1
    # The route matches and should handle the scope.
    FULL = 2
42
43
def iscoroutinefunction_or_partial(obj: typing.Any) -> bool:  # pragma: no cover
    """
    Deprecated: tell whether `obj` is a coroutine function.

    Layers of `functools.partial` are unwrapped before the check, so a
    partial around an `async def` is still recognised. A
    `DeprecationWarning` is issued through `warnings.warn` on each call.
    """
    message = (
        "iscoroutinefunction_or_partial is deprecated, "
        "and will be removed in a future release."
    )
    warnings.warn(message, DeprecationWarning)

    target = obj
    while isinstance(target, functools.partial):
        target = target.func
    return inspect.iscoroutinefunction(target)
57
58
def request_response(
    func: typing.Callable[[Request], typing.Awaitable[Response] | Response],
) -> ASGIApp:
    """
    Adapt an endpoint of the form `func(request) -> response` into an
    ASGI application.

    `func` may be a coroutine function, which is awaited directly, or a
    plain function, which is executed through `run_in_threadpool`.
    """

    async def app(scope: Scope, receive: Receive, send: Send) -> None:
        request = Request(scope, receive, send)

        async def call_endpoint(scope: Scope, receive: Receive, send: Send) -> None:
            if not is_async_callable(func):
                # Plain callables are handed to the threadpool helper.
                response = await run_in_threadpool(func, request)
            else:
                response = await func(request)
            # The returned response is itself invoked as an ASGI callable.
            await response(scope, receive, send)

        guarded = wrap_app_handling_exceptions(call_endpoint, request)
        await guarded(scope, receive, send)

    return app
80
81
def websocket_session(
    func: typing.Callable[[WebSocket], typing.Awaitable[None]],
) -> ASGIApp:
    """
    Adapt a coroutine `func(session)` into an ASGI application.

    The result of `func(session)` is awaited, so `func` is expected to be
    async; this is not checked when the wrapper is created.
    """

    async def app(scope: Scope, receive: Receive, send: Send) -> None:
        session = WebSocket(scope, receive=receive, send=send)

        async def call_endpoint(scope: Scope, receive: Receive, send: Send) -> None:
            await func(session)

        guarded = wrap_app_handling_exceptions(call_endpoint, session)
        await guarded(scope, receive, send)

    return app
99
100
def get_name(endpoint: typing.Callable[..., typing.Any]) -> str:
    """
    Derive a default name for an endpoint.

    Routines and classes are named by their own `__name__`; any other
    object is named after its class.
    """
    has_own_name = inspect.isclass(endpoint) or inspect.isroutine(endpoint)
    named_object = endpoint if has_own_name else endpoint.__class__
    return named_object.__name__
105
106
def replace_params(
    path: str,
    param_convertors: dict[str, Convertor[typing.Any]],
    path_params: dict[str, str],
) -> tuple[str, dict[str, str]]:
    """
    Substitute `{name}` placeholders in `path` with converted values.

    For every key of `path_params` whose placeholder occurs in `path`, the
    value is rendered with the matching convertor's `to_string()`, every
    occurrence of the placeholder is replaced, and the key is removed from
    `path_params`.

    Returns the filled-in path together with `path_params` itself (the same
    dict object, mutated in place), now holding only the unused parameters.
    """
    # Iterate over a snapshot of the keys because entries are deleted below.
    for key in list(path_params):
        placeholder = "{" + key + "}"
        if placeholder not in path:
            continue
        rendered = param_convertors[key].to_string(path_params[key])
        path = path.replace(placeholder, rendered)
        del path_params[key]
    return path, path_params
119
120
# Finds path parameters such as '{param}' and '{param:int}'.
# Group 1 is the parameter name; group 2 is the optional ':convertor' suffix
# (including the leading colon).
PARAM_REGEX = re.compile(r"{([a-zA-Z_][a-zA-Z0-9_]*)(:[a-zA-Z_][a-zA-Z0-9_]*)?}")
123
124
def compile_path(
    path: str,
) -> tuple[typing.Pattern[str], str, dict[str, Convertor[typing.Any]]]:
    """
    Translate a path (or host) template into its matching machinery.

    Given a path string, like: "/{username:str}",
    or a host string, like: "{subdomain}.mydomain.org", return a three-tuple
    of (regex, format, {param_name:convertor}).

    regex:      "/(?P<username>[^/]+)"
    format:     "/{username}"
    convertors: {"username": StringConvertor()}

    A template that does not start with "/" is treated as a host. Raises
    `ValueError` if a parameter name is used more than once, and fails an
    assertion for an unknown convertor type.
    """
    is_host = not path.startswith("/")

    regex_parts = ["^"]
    format_parts = []
    param_convertors = {}
    duplicated_params = set()
    cursor = 0

    for match in PARAM_REGEX.finditer(path):
        # A parameter without an explicit ':type' suffix defaults to 'str'.
        param_name, convertor_type = match.groups("str")
        convertor_type = convertor_type.lstrip(":")
        assert (
            convertor_type in CONVERTOR_TYPES
        ), f"Unknown path convertor '{convertor_type}'"
        convertor = CONVERTOR_TYPES[convertor_type]

        # Literal text between the previous parameter and this one.
        literal = path[cursor : match.start()]
        regex_parts.append(re.escape(literal))
        regex_parts.append(f"(?P<{param_name}>{convertor.regex})")
        format_parts.append(literal)
        format_parts.append("{%s}" % param_name)

        if param_name in param_convertors:
            duplicated_params.add(param_name)
        param_convertors[param_name] = convertor

        cursor = match.end()

    if duplicated_params:
        names = ", ".join(sorted(duplicated_params))
        ending = "s" if len(duplicated_params) > 1 else ""
        raise ValueError(f"Duplicated param name{ending} {names} at path {path}")

    tail = path[cursor:]
    # Align with `Host.matches()` behavior, which ignores port: for hosts
    # only the part before the first ':' takes part in the regex.
    regex_tail = tail.split(":")[0] if is_host else tail
    regex_parts.append(re.escape(regex_tail) + "$")
    format_parts.append(tail)

    return re.compile("".join(regex_parts)), "".join(format_parts), param_convertors
181
182
class BaseRoute:
    """
    Common interface for routable objects.

    Subclasses provide `matches()`, `url_path_for()` and `handle()`; the
    base implementations only raise `NotImplementedError`.
    """

    def matches(self, scope: Scope) -> tuple[Match, Scope]:
        raise NotImplementedError()  # pragma: no cover

    def url_path_for(self, name: str, /, **path_params: typing.Any) -> URLPath:
        raise NotImplementedError()  # pragma: no cover

    async def handle(self, scope: Scope, receive: Receive, send: Send) -> None:
        raise NotImplementedError()  # pragma: no cover

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """
        Run this route on its own as an ASGI app.

        Routes are almost always used inside a Router, so this is a somewhat
        contrived case, but it can be handy for tooling and minimal apps.
        """
        match, child_scope = self.matches(scope)
        if match != Match.NONE:
            # Any kind of match: merge the route's additions and dispatch.
            scope.update(child_scope)
            await self.handle(scope, receive, send)
            return

        # No match: answer according to the connection type. Other scope
        # types fall through without sending anything.
        if scope["type"] == "http":
            response = PlainTextResponse("Not Found", status_code=404)
            await response(scope, receive, send)
        elif scope["type"] == "websocket":
            websocket_close = WebSocketClose()
            await websocket_close(scope, receive, send)
211
212
class Route(BaseRoute):
    """
    An HTTP route: a path template bound to an endpoint.

    Parameters:
        path: Path template; must start with "/". May contain `{name}` or
            `{name:convertor}` parameters.
        endpoint: Either a function/method (possibly wrapped in
            `functools.partial`) taking a request and returning a response,
            or any other object, which is used directly as the ASGI app.
        methods: Allowed HTTP methods. Function/method endpoints default to
            ["GET"]; other endpoints default to no restriction. "HEAD" is
            added automatically whenever "GET" is allowed.
        name: Route name for reverse lookups; defaults to `get_name(endpoint)`.
        include_in_schema: Stored on the instance; not otherwise used here.
        middleware: Optional sequence of middleware wrapped around the
            endpoint app, the first item ending up outermost.
    """

    def __init__(
        self,
        path: str,
        endpoint: typing.Callable[..., typing.Any],
        *,
        methods: list[str] | None = None,
        name: str | None = None,
        include_in_schema: bool = True,
        middleware: typing.Sequence[Middleware] | None = None,
    ) -> None:
        assert path.startswith("/"), "Routed paths must start with '/'"
        self.path = path
        self.endpoint = endpoint
        self.name = get_name(endpoint) if name is None else name
        self.include_in_schema = include_in_schema

        # Look through any functools.partial layers to classify the endpoint.
        endpoint_handler = endpoint
        while isinstance(endpoint_handler, functools.partial):
            endpoint_handler = endpoint_handler.func
        if inspect.isfunction(endpoint_handler) or inspect.ismethod(endpoint_handler):
            # Endpoint is function or method. Treat it as `func(request) -> response`.
            self.app = request_response(endpoint)
            if methods is None:
                methods = ["GET"]
        else:
            # Endpoint is a class. Treat it as ASGI.
            self.app = endpoint

        if middleware is not None:
            for cls, args, kwargs in reversed(middleware):
                # Pass the wrapped app positionally (as `Router` does). With
                # `app=` as a keyword, the first positional middleware
                # argument would be bound to `app` as well and the call
                # would fail with "got multiple values for argument 'app'".
                self.app = cls(self.app, *args, **kwargs)

        if methods is None:
            self.methods = None
        else:
            self.methods = {method.upper() for method in methods}
            if "GET" in self.methods:
                self.methods.add("HEAD")

        self.path_regex, self.path_format, self.param_convertors = compile_path(path)

    def matches(self, scope: Scope) -> tuple[Match, Scope]:
        """
        Test an HTTP scope against this route.

        Returns `Match.FULL` when path and method match, `Match.PARTIAL` when
        only the path matches, and `Match.NONE` otherwise. On a match the
        second element holds the endpoint and the merged, converted path
        parameters.
        """
        path_params: dict[str, typing.Any]
        if scope["type"] == "http":
            route_path = get_route_path(scope)
            match = self.path_regex.match(route_path)
            if match:
                matched_params = match.groupdict()
                for key, value in matched_params.items():
                    matched_params[key] = self.param_convertors[key].convert(value)
                # Parameters already in the scope are kept; ours override
                # them on a name clash.
                path_params = dict(scope.get("path_params", {}))
                path_params.update(matched_params)
                child_scope = {"endpoint": self.endpoint, "path_params": path_params}
                if self.methods and scope["method"] not in self.methods:
                    return Match.PARTIAL, child_scope
                else:
                    return Match.FULL, child_scope
        return Match.NONE, {}

    def url_path_for(self, name: str, /, **path_params: typing.Any) -> URLPath:
        """
        Build the URL path for this route.

        Raises `NoMatchFound` unless `name` equals the route name and the
        given parameter names are exactly the ones in the path template.
        """
        seen_params = set(path_params.keys())
        expected_params = set(self.param_convertors.keys())

        if name != self.name or seen_params != expected_params:
            raise NoMatchFound(name, path_params)

        path, remaining_params = replace_params(
            self.path_format, self.param_convertors, path_params
        )
        assert not remaining_params
        return URLPath(path=path, protocol="http")

    async def handle(self, scope: Scope, receive: Receive, send: Send) -> None:
        """
        Dispatch to the endpoint app, or answer 405 for a disallowed method.

        When "app" is present in the scope an `HTTPException` is raised
        instead of sending the 405 response directly.
        """
        if self.methods and scope["method"] not in self.methods:
            # Sorted so the header value does not depend on set iteration
            # order, which can differ between interpreter runs.
            headers = {"Allow": ", ".join(sorted(self.methods))}
            if "app" in scope:
                raise HTTPException(status_code=405, headers=headers)
            else:
                response = PlainTextResponse(
                    "Method Not Allowed", status_code=405, headers=headers
                )
            await response(scope, receive, send)
        else:
            await self.app(scope, receive, send)

    def __eq__(self, other: typing.Any) -> bool:
        # Routes are equal when path, endpoint and allowed methods agree.
        return (
            isinstance(other, Route)
            and self.path == other.path
            and self.endpoint == other.endpoint
            and self.methods == other.methods
        )

    def __repr__(self) -> str:
        class_name = self.__class__.__name__
        methods = sorted(self.methods or [])
        path, name = self.path, self.name
        return f"{class_name}(path={path!r}, name={name!r}, methods={methods!r})"
312
313
class WebSocketRoute(BaseRoute):
    """
    A WebSocket route: a path template bound to an endpoint.

    Parameters:
        path: Path template; must start with "/".
        endpoint: Either a function/method (possibly wrapped in
            `functools.partial`) called as `func(websocket)`, or any other
            object, which is used directly as the ASGI app.
        name: Route name for reverse lookups; defaults to `get_name(endpoint)`.
        middleware: Optional sequence of middleware wrapped around the
            endpoint app, the first item ending up outermost.
    """

    def __init__(
        self,
        path: str,
        endpoint: typing.Callable[..., typing.Any],
        *,
        name: str | None = None,
        middleware: typing.Sequence[Middleware] | None = None,
    ) -> None:
        assert path.startswith("/"), "Routed paths must start with '/'"
        self.path = path
        self.endpoint = endpoint
        self.name = get_name(endpoint) if name is None else name

        # Look through any functools.partial layers to classify the endpoint.
        endpoint_handler = endpoint
        while isinstance(endpoint_handler, functools.partial):
            endpoint_handler = endpoint_handler.func
        if inspect.isfunction(endpoint_handler) or inspect.ismethod(endpoint_handler):
            # Endpoint is function or method. Treat it as `func(websocket)`.
            self.app = websocket_session(endpoint)
        else:
            # Endpoint is a class. Treat it as ASGI.
            self.app = endpoint

        if middleware is not None:
            for cls, args, kwargs in reversed(middleware):
                # Pass the wrapped app positionally (as `Router` does). With
                # `app=` as a keyword, the first positional middleware
                # argument would be bound to `app` as well and the call
                # would fail with "got multiple values for argument 'app'".
                self.app = cls(self.app, *args, **kwargs)

        self.path_regex, self.path_format, self.param_convertors = compile_path(path)

    def matches(self, scope: Scope) -> tuple[Match, Scope]:
        """
        Test a WebSocket scope against this route.

        Returns `Match.FULL` with the endpoint and merged, converted path
        parameters when the path matches, otherwise `Match.NONE`.
        """
        path_params: dict[str, typing.Any]
        if scope["type"] == "websocket":
            route_path = get_route_path(scope)
            match = self.path_regex.match(route_path)
            if match:
                matched_params = match.groupdict()
                for key, value in matched_params.items():
                    matched_params[key] = self.param_convertors[key].convert(value)
                # Parameters already in the scope are kept; ours override
                # them on a name clash.
                path_params = dict(scope.get("path_params", {}))
                path_params.update(matched_params)
                child_scope = {"endpoint": self.endpoint, "path_params": path_params}
                return Match.FULL, child_scope
        return Match.NONE, {}

    def url_path_for(self, name: str, /, **path_params: typing.Any) -> URLPath:
        """
        Build the URL path for this route.

        Raises `NoMatchFound` unless `name` equals the route name and the
        given parameter names are exactly the ones in the path template.
        """
        seen_params = set(path_params.keys())
        expected_params = set(self.param_convertors.keys())

        if name != self.name or seen_params != expected_params:
            raise NoMatchFound(name, path_params)

        path, remaining_params = replace_params(
            self.path_format, self.param_convertors, path_params
        )
        assert not remaining_params
        return URLPath(path=path, protocol="websocket")

    async def handle(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Delegate the connection to the (possibly middleware-wrapped) app."""
        await self.app(scope, receive, send)

    def __eq__(self, other: typing.Any) -> bool:
        # WebSocket routes are equal when path and endpoint agree.
        return (
            isinstance(other, WebSocketRoute)
            and self.path == other.path
            and self.endpoint == other.endpoint
        )

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(path={self.path!r}, name={self.name!r})"
384
385
class Mount(BaseRoute):
    """
    Route that hands everything under a path prefix to another ASGI app.

    Parameters:
        path: Prefix to mount at; either "" or a string starting with "/".
            Trailing slashes are stripped.
        app: The ASGI app to delegate to. If omitted, a `Router` is built
            from `routes`. At least one of `app` / `routes` is required.
        routes: Routes for the implicitly created `Router`.
        name: Optional mount name, used as the "<mount_name>:<child_name>"
            prefix in reverse lookups.
        middleware: Optional sequence of middleware wrapped around the
            mounted app, the first item ending up outermost.
    """

    def __init__(
        self,
        path: str,
        app: ASGIApp | None = None,
        routes: typing.Sequence[BaseRoute] | None = None,
        name: str | None = None,
        *,
        middleware: typing.Sequence[Middleware] | None = None,
    ) -> None:
        assert path == "" or path.startswith("/"), "Routed paths must start with '/'"
        assert (
            app is not None or routes is not None
        ), "Either 'app=...', or 'routes=' must be specified"
        self.path = path.rstrip("/")
        if app is not None:
            self._base_app: ASGIApp = app
        else:
            self._base_app = Router(routes=routes)
        # `_base_app` stays unwrapped so `routes` can still be read from it.
        self.app = self._base_app
        if middleware is not None:
            for cls, args, kwargs in reversed(middleware):
                # Pass the wrapped app positionally (as `Router` does). With
                # `app=` as a keyword, the first positional middleware
                # argument would be bound to `app` as well and the call
                # would fail with "got multiple values for argument 'app'".
                self.app = cls(self.app, *args, **kwargs)
        self.name = name
        # The trailing `{path:path}` parameter captures whatever follows
        # the mount prefix.
        self.path_regex, self.path_format, self.param_convertors = compile_path(
            self.path + "/{path:path}"
        )

    @property
    def routes(self) -> list[BaseRoute]:
        """Routes of the underlying app, or an empty list if it has none."""
        return getattr(self._base_app, "routes", [])

    def matches(self, scope: Scope) -> tuple[Match, Scope]:
        """
        Test an HTTP or WebSocket scope against the mount prefix.

        On a match, returns `Match.FULL` and a child scope whose `root_path`
        is extended by the matched prefix; otherwise `Match.NONE`.
        """
        path_params: dict[str, typing.Any]
        if scope["type"] in ("http", "websocket"):
            root_path = scope.get("root_path", "")
            route_path = get_route_path(scope)
            match = self.path_regex.match(route_path)
            if match:
                matched_params = match.groupdict()
                for key, value in matched_params.items():
                    matched_params[key] = self.param_convertors[key].convert(value)
                # Split the route path into the mounted prefix and the rest.
                remaining_path = "/" + matched_params.pop("path")
                matched_path = route_path[: -len(remaining_path)]
                path_params = dict(scope.get("path_params", {}))
                path_params.update(matched_params)
                child_scope = {
                    "path_params": path_params,
                    # app_root_path will only be set at the top level scope,
                    # initialized with the (optional) value of a root_path
                    # set above/before Starlette. And even though any
                    # mount will have its own child scope with its own respective
                    # root_path, the app_root_path will always be available in all
                    # the child scopes with the same top level value because it's
                    # set only once here with a default, any other child scope will
                    # just inherit that app_root_path default value stored in the
                    # scope. All this is needed to support Request.url_for(), as it
                    # uses the app_root_path to build the URL path.
                    "app_root_path": scope.get("app_root_path", root_path),
                    "root_path": root_path + matched_path,
                    "endpoint": self.app,
                }
                return Match.FULL, child_scope
        return Match.NONE, {}

    def url_path_for(self, name: str, /, **path_params: typing.Any) -> URLPath:
        """
        Build a URL path for the mount itself or for one of its child routes.

        `name` may be the mount name (with a `path` parameter) or, for child
        routes, "<mount_name>:<child_name>" (just "<child_name>" when the
        mount is unnamed). Raises `NoMatchFound` if nothing matches.
        """
        if self.name is not None and name == self.name and "path" in path_params:
            # 'name' matches "<mount_name>".
            path_params["path"] = path_params["path"].lstrip("/")
            path, remaining_params = replace_params(
                self.path_format, self.param_convertors, path_params
            )
            if not remaining_params:
                return URLPath(path=path)
        elif self.name is None or name.startswith(self.name + ":"):
            if self.name is None:
                # No mount name.
                remaining_name = name
            else:
                # 'name' matches "<mount_name>:<child_name>".
                remaining_name = name[len(self.name) + 1 :]
            # Render the prefix with an empty `path`, then give any
            # caller-supplied `path` value back to the child lookup.
            path_kwarg = path_params.get("path")
            path_params["path"] = ""
            path_prefix, remaining_params = replace_params(
                self.path_format, self.param_convertors, path_params
            )
            if path_kwarg is not None:
                remaining_params["path"] = path_kwarg
            for route in self.routes or []:
                try:
                    url = route.url_path_for(remaining_name, **remaining_params)
                    return URLPath(
                        path=path_prefix.rstrip("/") + str(url), protocol=url.protocol
                    )
                except NoMatchFound:
                    pass
        raise NoMatchFound(name, path_params)

    async def handle(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Delegate the connection to the (possibly middleware-wrapped) app."""
        await self.app(scope, receive, send)

    def __eq__(self, other: typing.Any) -> bool:
        # Mounts are equal when path and app agree.
        return (
            isinstance(other, Mount)
            and self.path == other.path
            and self.app == other.app
        )

    def __repr__(self) -> str:
        class_name = self.__class__.__name__
        name = self.name or ""
        return f"{class_name}(path={self.path!r}, name={name!r}, app={self.app!r})"
498
499
class Host(BaseRoute):
    """
    Route that selects an ASGI app by the request's `Host` header.

    `host` is a template such as "{subdomain}.mydomain.org" and must not
    start with "/". The port part of the header is ignored when matching.
    """

    def __init__(self, host: str, app: ASGIApp, name: str | None = None) -> None:
        assert not host.startswith("/"), "Host must not start with '/'"
        self.host = host
        self.app = app
        self.name = name
        compiled = compile_path(host)
        self.host_regex, self.host_format, self.param_convertors = compiled

    @property
    def routes(self) -> list[BaseRoute]:
        """Routes of the wrapped app, or an empty list if it has none."""
        return getattr(self.app, "routes", [])

    def matches(self, scope: Scope) -> tuple[Match, Scope]:
        """
        Test an HTTP or WebSocket scope against the host template.

        Returns `Match.FULL` with the merged, converted parameters and the
        app as endpoint, or `Match.NONE`.
        """
        if scope["type"] not in ("http", "websocket"):
            return Match.NONE, {}

        # Only the hostname takes part in matching; drop any ":port".
        header_value = Headers(scope=scope).get("host", "")
        hostname = header_value.split(":")[0]
        match = self.host_regex.match(hostname)
        if match is None:
            return Match.NONE, {}

        converted = {
            key: self.param_convertors[key].convert(value)
            for key, value in match.groupdict().items()
        }
        path_params = dict(scope.get("path_params", {}))
        path_params.update(converted)
        return Match.FULL, {"path_params": path_params, "endpoint": self.app}

    def url_path_for(self, name: str, /, **path_params: typing.Any) -> URLPath:
        """
        Build a URL (with host) for this entry or one of its child routes.

        `name` may be the host entry's own name (with a `path` parameter)
        or "<mount_name>:<child_name>" (just "<child_name>" when unnamed).
        Raises `NoMatchFound` if nothing matches.
        """
        if self.name is not None and name == self.name and "path" in path_params:
            # 'name' matches "<mount_name>".
            path = path_params.pop("path")
            host, remaining_params = replace_params(
                self.host_format, self.param_convertors, path_params
            )
            if not remaining_params:
                return URLPath(path=path, host=host)
        elif self.name is None or name.startswith(self.name + ":"):
            # With no mount name the whole name belongs to the child;
            # otherwise strip the "<mount_name>:" prefix.
            remaining_name = name if self.name is None else name[len(self.name) + 1 :]
            host, remaining_params = replace_params(
                self.host_format, self.param_convertors, path_params
            )
            for route in self.routes or []:
                try:
                    url = route.url_path_for(remaining_name, **remaining_params)
                except NoMatchFound:
                    continue
                return URLPath(path=str(url), protocol=url.protocol, host=host)
        raise NoMatchFound(name, path_params)

    async def handle(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Delegate the connection to the wrapped app."""
        await self.app(scope, receive, send)

    def __eq__(self, other: typing.Any) -> bool:
        if not isinstance(other, Host):
            return False
        return self.host == other.host and self.app == other.app

    def __repr__(self) -> str:
        class_name = self.__class__.__name__
        name = self.name or ""
        return f"{class_name}(host={self.host!r}, name={name!r}, app={self.app!r})"
568
569
# Generic type variable: the value type produced by `_AsyncLiftContextManager`
# and the self-type returned by `_DefaultLifespan.__call__`.
_T = typing.TypeVar("_T")
571
572
class _AsyncLiftContextManager(typing.AsyncContextManager[_T]):
    """
    Present a synchronous context manager through the async protocol.

    `__aenter__` / `__aexit__` call the wrapped manager's `__enter__` /
    `__exit__` directly and pass their results through.
    """

    def __init__(self, cm: typing.ContextManager[_T]):
        self._cm = cm

    async def __aenter__(self) -> _T:
        entered_value = self._cm.__enter__()
        return entered_value

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        traceback: types.TracebackType | None,
    ) -> bool | None:
        # Hand back whatever the wrapped manager's `__exit__` returns.
        exit_result = self._cm.__exit__(exc_type, exc_value, traceback)
        return exit_result
587
588
def _wrap_gen_lifespan_context(
    lifespan_context: typing.Callable[
        [typing.Any], typing.Generator[typing.Any, typing.Any, typing.Any]
    ],
) -> typing.Callable[[typing.Any], typing.AsyncContextManager[typing.Any]]:
    """
    Turn a generator-function lifespan into an async-context-manager factory.

    The generator function is first wrapped with `contextlib.contextmanager`;
    the returned callable then lifts each resulting sync context manager
    into `_AsyncLiftContextManager`.
    """
    sync_factory = contextlib.contextmanager(lifespan_context)

    @functools.wraps(sync_factory)
    def wrapper(app: typing.Any) -> _AsyncLiftContextManager[typing.Any]:
        sync_cm = sync_factory(app)
        return _AsyncLiftContextManager(sync_cm)

    return wrapper
601
602
class _DefaultLifespan:
    """
    Lifespan used when a router is given no explicit `lifespan`.

    Entering awaits the router's `startup()`, leaving awaits its
    `shutdown()`. Calling the instance with an app returns the instance
    itself, so it can stand in for a lifespan factory.
    """

    def __init__(self, router: Router):
        self._router = router

    async def __aenter__(self) -> None:
        router = self._router
        await router.startup()

    async def __aexit__(self, *exc_info: object) -> None:
        # Exception details are ignored; shutdown always runs.
        router = self._router
        await router.shutdown()

    def __call__(self: _T, app: object) -> _T:
        # The app argument is not used.
        return self
615
616
class Router:
    """
    ASGI application that dispatches to the first matching route.

    Parameters:
        routes: Initial routes, tried in order.
        redirect_slashes: When no route matches an HTTP request, retry with
            the trailing slash toggled and redirect if that matches.
        default: ASGI app used when nothing matches; defaults to
            `self.not_found`.
        on_startup / on_shutdown: Deprecated event handler lists, run by
            `startup()` / `shutdown()`.
        lifespan: Lifespan context factory. Async-generator and generator
            functions are accepted but deprecated.
        middleware: Optional sequence of middleware wrapped around
            `self.app`, the first item ending up outermost.
    """

    def __init__(
        self,
        routes: typing.Sequence[BaseRoute] | None = None,
        redirect_slashes: bool = True,
        default: ASGIApp | None = None,
        on_startup: typing.Sequence[typing.Callable[[], typing.Any]] | None = None,
        on_shutdown: typing.Sequence[typing.Callable[[], typing.Any]] | None = None,
        # the generic to Lifespan[AppType] is the type of the top level application
        # which the router cannot know statically, so we use typing.Any
        lifespan: Lifespan[typing.Any] | None = None,
        *,
        middleware: typing.Sequence[Middleware] | None = None,
    ) -> None:
        self.routes = [] if routes is None else list(routes)
        self.redirect_slashes = redirect_slashes
        self.default = self.not_found if default is None else default
        self.on_startup = [] if on_startup is None else list(on_startup)
        self.on_shutdown = [] if on_shutdown is None else list(on_shutdown)

        if on_startup or on_shutdown:
            warnings.warn(
                "The on_startup and on_shutdown parameters are deprecated, and they "
                "will be removed on version 1.0. Use the lifespan parameter instead. "
                "See more about it on https://www.starlette.io/lifespan/.",
                DeprecationWarning,
            )
            if lifespan:
                warnings.warn(
                    "The `lifespan` parameter cannot be used with `on_startup` or "
                    "`on_shutdown`. Both `on_startup` and `on_shutdown` will be "
                    "ignored."
                )

        # Normalise every accepted lifespan form into a callable that takes
        # the app and returns an async context manager.
        if lifespan is None:
            self.lifespan_context: Lifespan[typing.Any] = _DefaultLifespan(self)

        elif inspect.isasyncgenfunction(lifespan):
            warnings.warn(
                "async generator function lifespans are deprecated, "
                "use an @contextlib.asynccontextmanager function instead",
                DeprecationWarning,
            )
            self.lifespan_context = asynccontextmanager(
                lifespan,
            )
        elif inspect.isgeneratorfunction(lifespan):
            warnings.warn(
                "generator function lifespans are deprecated, "
                "use an @contextlib.asynccontextmanager function instead",
                DeprecationWarning,
            )
            self.lifespan_context = _wrap_gen_lifespan_context(
                lifespan,
            )
        else:
            self.lifespan_context = lifespan

        # Wrap `self.app` so that the first middleware ends up outermost.
        self.middleware_stack = self.app
        if middleware:
            for cls, args, kwargs in reversed(middleware):
                self.middleware_stack = cls(self.middleware_stack, *args, **kwargs)

    async def not_found(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Default handler: close WebSockets, answer HTTP with a 404."""
        if scope["type"] == "websocket":
            websocket_close = WebSocketClose()
            await websocket_close(scope, receive, send)
            return

        # If we're running inside a starlette application then raise an
        # exception, so that the configurable exception handler can deal with
        # returning the response. For plain ASGI apps, just return the response.
        if "app" in scope:
            raise HTTPException(status_code=404)
        else:
            response = PlainTextResponse("Not Found", status_code=404)
        await response(scope, receive, send)

    def url_path_for(self, name: str, /, **path_params: typing.Any) -> URLPath:
        """
        Return the URL path from the first route that can build one for
        `name` and `path_params`; raise `NoMatchFound` if none can.
        """
        for route in self.routes:
            try:
                return route.url_path_for(name, **path_params)
            except NoMatchFound:
                pass
        raise NoMatchFound(name, path_params)

    async def startup(self) -> None:
        """
        Run any `.on_startup` event handlers.
        """
        for handler in self.on_startup:
            if is_async_callable(handler):
                await handler()
            else:
                handler()

    async def shutdown(self) -> None:
        """
        Run any `.on_shutdown` event handlers.
        """
        for handler in self.on_shutdown:
            if is_async_callable(handler):
                await handler()
            else:
                handler()

    async def lifespan(self, scope: Scope, receive: Receive, send: Send) -> None:
        """
        Handle ASGI lifespan messages, which allows us to manage application
        startup and shutdown events.
        """
        started = False
        app: typing.Any = scope.get("app")
        await receive()
        try:
            async with self.lifespan_context(app) as maybe_state:
                if maybe_state is not None:
                    if "state" not in scope:
                        raise RuntimeError(
                            'The server does not support "state" in the lifespan scope.'
                        )
                    scope["state"].update(maybe_state)
                await send({"type": "lifespan.startup.complete"})
                started = True
                # Wait for the next message before leaving the context.
                await receive()
        except BaseException:
            # Report the failure for whichever phase was in progress, then
            # let the exception propagate.
            exc_text = traceback.format_exc()
            if started:
                await send({"type": "lifespan.shutdown.failed", "message": exc_text})
            else:
                await send({"type": "lifespan.startup.failed", "message": exc_text})
            raise
        else:
            await send({"type": "lifespan.shutdown.complete"})

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """
        The main entry point to the Router class.
        """
        await self.middleware_stack(scope, receive, send)

    async def app(self, scope: Scope, receive: Receive, send: Send) -> None:
        """
        Route the scope: lifespan handling, full match, partial match,
        slash redirect, and finally the default handler, in that order.
        """
        assert scope["type"] in ("http", "websocket", "lifespan")

        if "router" not in scope:
            scope["router"] = self

        if scope["type"] == "lifespan":
            await self.lifespan(scope, receive, send)
            return

        partial = None

        for route in self.routes:
            # Determine if any route matches the incoming scope,
            # and hand over to the matching route if found.
            match, child_scope = route.matches(scope)
            if match == Match.FULL:
                scope.update(child_scope)
                await route.handle(scope, receive, send)
                return
            elif match == Match.PARTIAL and partial is None:
                # Remember only the first partial match.
                partial = route
                partial_scope = child_scope

        if partial is not None:
            #  Handle partial matches. These are cases where an endpoint is
            # able to handle the request, but is not a preferred option.
            # We use this in particular to deal with "405 Method Not Allowed".
            scope.update(partial_scope)
            await partial.handle(scope, receive, send)
            return

        route_path = get_route_path(scope)
        if scope["type"] == "http" and self.redirect_slashes and route_path != "/":
            # Try the same request with the trailing slash toggled.
            redirect_scope = dict(scope)
            if route_path.endswith("/"):
                redirect_scope["path"] = redirect_scope["path"].rstrip("/")
            else:
                redirect_scope["path"] = redirect_scope["path"] + "/"

            for route in self.routes:
                match, child_scope = route.matches(redirect_scope)
                if match != Match.NONE:
                    redirect_url = URL(scope=redirect_scope)
                    response = RedirectResponse(url=str(redirect_url))
                    await response(scope, receive, send)
                    return

        await self.default(scope, receive, send)

    def __eq__(self, other: typing.Any) -> bool:
        # Routers are equal when their route lists are equal.
        return isinstance(other, Router) and self.routes == other.routes

    def mount(
        self, path: str, app: ASGIApp, name: str | None = None
    ) -> None:  # pragma: nocover
        """Append a `Mount` for `app` at `path`."""
        route = Mount(path, app=app, name=name)
        self.routes.append(route)

    def host(
        self, host: str, app: ASGIApp, name: str | None = None
    ) -> None:  # pragma: no cover
        """Append a `Host` route for `app`."""
        route = Host(host, app=app, name=name)
        self.routes.append(route)

    def add_route(
        self,
        path: str,
        endpoint: typing.Callable[[Request], typing.Awaitable[Response] | Response],
        methods: list[str] | None = None,
        name: str | None = None,
        include_in_schema: bool = True,
    ) -> None:  # pragma: nocover
        """Append an HTTP `Route` built from the given arguments."""
        route = Route(
            path,
            endpoint=endpoint,
            methods=methods,
            name=name,
            include_in_schema=include_in_schema,
        )
        self.routes.append(route)

    def add_websocket_route(
        self,
        path: str,
        endpoint: typing.Callable[[WebSocket], typing.Awaitable[None]],
        name: str | None = None,
    ) -> None:  # pragma: no cover
        """Append a `WebSocketRoute` built from the given arguments."""
        route = WebSocketRoute(path, endpoint=endpoint, name=name)
        self.routes.append(route)

    def route(
        self,
        path: str,
        methods: list[str] | None = None,
        name: str | None = None,
        include_in_schema: bool = True,
    ) -> typing.Callable:  # type: ignore[type-arg]
        """
        We no longer document this decorator style API, and its usage is discouraged.
        Instead you should use the following approach:

        >>> routes = [Route(path, endpoint=...), ...]
        >>> app = Starlette(routes=routes)
        """
        # The first literal ends with a space so the two sentences do not
        # run together in the emitted warning.
        warnings.warn(
            "The `route` decorator is deprecated, and will be removed in version 1.0.0. "  # noqa: E501
            "Refer to https://www.starlette.io/routing/#http-routing for the recommended approach.",  # noqa: E501
            DeprecationWarning,
        )

        def decorator(func: typing.Callable) -> typing.Callable:  # type: ignore[type-arg]  # noqa: E501
            self.add_route(
                path,
                func,
                methods=methods,
                name=name,
                include_in_schema=include_in_schema,
            )
            return func

        return decorator

    def websocket_route(self, path: str, name: str | None = None) -> typing.Callable:  # type: ignore[type-arg]
        """
        We no longer document this decorator style API, and its usage is discouraged.
        Instead you should use the following approach:

        >>> routes = [WebSocketRoute(path, endpoint=...), ...]
        >>> app = Starlette(routes=routes)
        """
        warnings.warn(
            "The `websocket_route` decorator is deprecated, and will be removed in version 1.0.0. Refer to "  # noqa: E501
            "https://www.starlette.io/routing/#websocket-routing for the recommended approach.",  # noqa: E501
            DeprecationWarning,
        )

        def decorator(func: typing.Callable) -> typing.Callable:  # type: ignore[type-arg]  # noqa: E501
            self.add_websocket_route(path, func, name=name)
            return func

        return decorator

    def add_event_handler(
        self, event_type: str, func: typing.Callable[[], typing.Any]
    ) -> None:  # pragma: no cover
        """Register `func` for the "startup" or "shutdown" event."""
        assert event_type in ("startup", "shutdown")

        if event_type == "startup":
            self.on_startup.append(func)
        else:
            self.on_shutdown.append(func)

    def on_event(self, event_type: str) -> typing.Callable:  # type: ignore[type-arg]
        """Deprecated decorator form of `add_event_handler`."""
        warnings.warn(
            "The `on_event` decorator is deprecated, and will be removed in version 1.0.0. "  # noqa: E501
            "Refer to https://www.starlette.io/lifespan/ for recommended approach.",
            DeprecationWarning,
        )

        def decorator(func: typing.Callable) -> typing.Callable:  # type: ignore[type-arg]  # noqa: E501
            self.add_event_handler(event_type, func)
            return func

        return decorator