1from __future__ import annotations
2
3import functools
4import sys
5from collections.abc import Awaitable, Generator
6from contextlib import AbstractAsyncContextManager, contextmanager
7from typing import Any, Callable, Generic, Protocol, TypeVar, overload
8
9from starlette.types import Scope
10
11if sys.version_info >= (3, 13): # pragma: no cover
12 from inspect import iscoroutinefunction
13 from typing import TypeIs
14else: # pragma: no cover
15 from asyncio import iscoroutinefunction
16
17 from typing_extensions import TypeIs
18
# ``BaseExceptionGroup`` is a builtin from Python 3.11 onwards. Older
# interpreters only have it through the optional ``exceptiongroup`` backport.
# ``has_exceptiongroups`` records whether the name can be used below.
if sys.version_info >= (3, 11):  # pragma: no cover
    has_exceptiongroups = True
else:  # pragma: no cover
    try:
        from exceptiongroup import BaseExceptionGroup  # type: ignore[unused-ignore,import-not-found]
    except ImportError:
        has_exceptiongroups = False
    else:
        has_exceptiongroups = True

T = TypeVar("T")
# Any callable whose invocation returns an awaitable resolving to ``T``.
AwaitableCallable = Callable[..., Awaitable[T]]
28
29
@overload
def is_async_callable(obj: AwaitableCallable[T]) -> TypeIs[AwaitableCallable[T]]: ...


@overload
def is_async_callable(obj: Any) -> TypeIs[AwaitableCallable[Any]]: ...


def is_async_callable(obj: Any) -> Any:
    """Return whether calling *obj* produces a coroutine.

    Any number of nested ``functools.partial`` layers are peeled off first.
    The remaining object counts as async when it is itself a coroutine
    function, or when it is callable and its ``__call__`` is one.
    """
    target = obj
    # Peel off partial wrappers so the underlying callable is inspected.
    while isinstance(target, functools.partial):
        target = target.func

    if iscoroutinefunction(target):
        return True
    # Instances of classes that define ``async def __call__``.
    return callable(target) and iscoroutinefunction(target.__call__)
43
44
# Covariant result type for the protocol below.
T_co = TypeVar("T_co", covariant=True)


class AwaitableOrContextManager(Awaitable[T_co], AbstractAsyncContextManager[T_co], Protocol[T_co]):
    """Structural type for objects usable with either ``await`` or ``async with``.

    Both forms are typed to produce a ``T_co``.
    """
49
50
class SupportsAsyncClose(Protocol):
    """Structural type for objects that expose an awaitable ``close()``."""

    async def close(self) -> None:  # pragma: no cover
        ...


# Invariant type variable (the TypeVar default) bounded by SupportsAsyncClose.
SupportsAsyncCloseType = TypeVar("SupportsAsyncCloseType", bound=SupportsAsyncClose)
56
57
class AwaitableOrContextManagerWrapper(Generic[SupportsAsyncCloseType]):
    """Wrap an awaitable of a closeable resource so it supports two forms.

    ``await wrapper`` yields the resource directly. ``async with wrapper``
    awaits the resource on entry and awaits its ``close()`` on exit.
    Exceptions raised inside the ``async with`` body are not suppressed.

    NOTE(review): ``aw`` is presumably a one-shot coroutine, so only one of
    the two forms should be used per wrapper. Confirm against callers.
    """

    __slots__ = ("aw", "entered")

    def __init__(self, aw: Awaitable[SupportsAsyncCloseType]) -> None:
        # ``entered`` is only assigned once ``__aenter__`` runs.
        self.aw = aw

    def __await__(self) -> Generator[Any, None, SupportsAsyncCloseType]:
        # Delegate directly to the wrapped awaitable's iterator.
        inner = self.aw
        return inner.__await__()

    async def __aenter__(self) -> SupportsAsyncCloseType:
        resource = await self.aw
        self.entered = resource
        return resource

    async def __aexit__(self, *args: Any) -> None | bool:
        # Returning None (falsy) lets any in-flight exception propagate.
        resource = self.entered
        await resource.close()
        return None
74
75
@contextmanager
def collapse_excgroups() -> Generator[None, None, None]:
    """Re-raise exceptions from the managed block, unwrapping single-member groups.

    When exception groups are available (``has_exceptiongroups``), a
    ``BaseExceptionGroup`` that holds exactly one exception is replaced by
    that exception. This repeats for nested single-member groups. Groups
    with several members, and any other exception, are re-raised unchanged.
    """
    try:
        yield
    except BaseException as error:
        if has_exceptiongroups:  # pragma: no cover
            while isinstance(error, BaseExceptionGroup):
                if len(error.exceptions) != 1:
                    # Multiple members: keep the group intact.
                    break
                error = error.exceptions[0]

        raise error
86
87
def get_route_path(scope: Scope) -> str:
    """Return ``scope["path"]`` relative to ``scope["root_path"]``.

    The root path is stripped only when the path starts with it at a segment
    boundary, meaning the next character is ``/``. A path equal to the root
    path yields ``""``. In every other case, including a missing or empty
    root path, the path is returned unchanged.
    """
    full_path: str = scope["path"]
    prefix = scope.get("root_path", "")

    if prefix and full_path.startswith(prefix):
        remainder = full_path[len(prefix) :]
        if not remainder:
            # The path is exactly the mount point.
            return ""
        if remainder.startswith("/"):
            return remainder

    # Either no usable prefix, or e.g. "/apiusers" under root "/api".
    return full_path