1from __future__ import annotations
2
3import functools
4import inspect
5import sys
6from collections.abc import Awaitable, Generator
7from contextlib import AbstractAsyncContextManager, contextmanager
8from typing import Any, Callable, Generic, Protocol, TypeVar, overload
9
10from starlette.types import Scope
11
12if sys.version_info >= (3, 10): # pragma: no cover
13 from typing import TypeGuard
14else: # pragma: no cover
15 from typing_extensions import TypeGuard
16
# Flag read by ``collapse_excgroups``: True when the name ``BaseExceptionGroup``
# can be used in this module, False when the backport import below failed.
has_exceptiongroups = True
if sys.version_info < (3, 11):  # pragma: no cover
    # On interpreters older than 3.11, try to obtain ``BaseExceptionGroup``
    # from the ``exceptiongroup`` package instead.
    try:
        from exceptiongroup import BaseExceptionGroup  # type: ignore[unused-ignore,import-not-found]
    except ImportError:
        # The package is not installed: record that exception groups
        # cannot be inspected so callers skip that logic.
        has_exceptiongroups = False
23
# Generic placeholder for the value produced by an awaitable.
T = TypeVar("T")
# A callable taking any arguments whose call returns an ``Awaitable[T]``.
AwaitableCallable = Callable[..., Awaitable[T]]
26
27
@overload
def is_async_callable(obj: AwaitableCallable[T]) -> TypeGuard[AwaitableCallable[T]]: ...


@overload
def is_async_callable(obj: Any) -> TypeGuard[AwaitableCallable[Any]]: ...


def is_async_callable(obj: Any) -> Any:
    """Report whether calling *obj* goes through a coroutine function.

    Any number of nested ``functools.partial`` wrappers are peeled off first.
    The result is true when the remaining object is itself a coroutine
    function, or when it is callable and its ``__call__`` attribute is a
    coroutine function; otherwise it is false.
    """
    target = obj
    # Partials hide the underlying function, so look through every layer.
    while isinstance(target, functools.partial):
        target = target.func

    if inspect.iscoroutinefunction(target):
        return True

    # Non-callables have no ``__call__`` worth inspecting.
    if not callable(target):
        return False

    return inspect.iscoroutinefunction(target.__call__)
41
42
# Covariant type variable used to parameterize ``AwaitableOrContextManager``.
T_co = TypeVar("T_co", covariant=True)


class AwaitableOrContextManager(Awaitable[T_co], AbstractAsyncContextManager[T_co], Protocol[T_co]):
    """Protocol combining ``Awaitable`` and ``AbstractAsyncContextManager`` over ``T_co``."""

    ...
47
48
class SupportsAsyncClose(Protocol):
    """Structural type for objects exposing an awaitable ``close()`` method."""

    async def close(self) -> None: ...  # pragma: no cover


# Invariant type variable restricted to objects that can be closed asynchronously.
SupportsAsyncCloseType = TypeVar("SupportsAsyncCloseType", bound=SupportsAsyncClose, covariant=False)


class AwaitableOrContextManagerWrapper(Generic[SupportsAsyncCloseType]):
    """Let one awaitable be used either with ``await`` or with ``async with``.

    Awaiting the wrapper delegates to the wrapped awaitable. Entering it as
    an async context manager awaits the wrapped awaitable, remembers the
    result in ``entered``, and on exit awaits that result's ``close()``.
    """

    __slots__ = ("aw", "entered")

    def __init__(self, aw: Awaitable[SupportsAsyncCloseType]) -> None:
        """Store the awaitable; ``entered`` stays unset until ``__aenter__`` runs."""
        self.aw = aw

    def __await__(self) -> Generator[Any, None, SupportsAsyncCloseType]:
        """Delegate awaiting to the wrapped awaitable."""
        wrapped = self.aw
        return wrapped.__await__()

    async def __aenter__(self) -> SupportsAsyncCloseType:
        """Await the wrapped awaitable, keep its result, and return it."""
        result = await self.aw
        self.entered = result
        return result

    async def __aexit__(self, *args: Any) -> None | bool:
        """Close the entered object; exception details in *args* are ignored."""
        resource = self.entered
        await resource.close()
        # Returning None means an in-flight exception is not suppressed.
        return None
72
73
@contextmanager
def collapse_excgroups() -> Generator[None, None, None]:
    """Re-raise exceptions from the managed block, unwrapping single-member groups.

    Any ``BaseException`` escaping the ``with`` body is caught. While the
    module-level ``has_exceptiongroups`` flag is true and the exception is a
    ``BaseExceptionGroup`` containing exactly one exception, it is replaced
    by that inner exception. The resulting exception is then raised.
    """
    try:
        yield
    except BaseException as exc:
        # Peel nested groups one layer at a time; stop at the first exception
        # that is not a group or that holds more than one member.
        while (
            has_exceptiongroups  # pragma: no cover
            and isinstance(exc, BaseExceptionGroup)
            and len(exc.exceptions) == 1
        ):
            exc = exc.exceptions[0]

        raise exc
84
85
def get_route_path(scope: Scope) -> str:
    """Return ``scope["path"]`` with a leading ``root_path`` prefix removed.

    The prefix is stripped only when ``root_path`` is non-empty and the path
    either equals it (giving ``""``) or continues with ``"/"`` right after it.
    In every other case the path is returned unchanged.
    """
    path: str = scope["path"]
    root_path = scope.get("root_path", "")

    # Nothing to strip when there is no root path or it is not a prefix.
    if not root_path or not path.startswith(root_path):
        return path

    remainder = path[len(root_path) :]
    # Accept an exact match or a match ending on a segment boundary, so that
    # a root of "/api" does not strip from a path such as "/apiv2/users".
    if not remainder or remainder[0] == "/":
        return remainder

    return path