1from __future__ import annotations
2
3import sys
4from typing import Any, Iterator, Protocol
5
6if sys.version_info >= (3, 10): # pragma: no cover
7 from typing import ParamSpec
8else: # pragma: no cover
9 from typing_extensions import ParamSpec
10
11from starlette.types import ASGIApp, Receive, Scope, Send
12
# Captures the constructor signature of a middleware class (the parameters that
# follow ``app``) so that ``Middleware(cls, *args, **kwargs)`` can be checked by
# a type checker against ``cls.__init__``.
P = ParamSpec("P")
14
15
class _MiddlewareClass(Protocol[P]):
    """Structural type of an ASGI middleware class.

    A conforming class is constructed with the wrapped ASGI ``app`` followed by
    arbitrary extra arguments (captured by ``P``), and its instances are ASGI
    callables taking ``(scope, receive, send)``. Used only for static typing.
    """

    def __init__(self, app: ASGIApp, *args: P.args, **kwargs: P.kwargs) -> None:
        ...  # pragma: no cover

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        ...  # pragma: no cover
22
23
class Middleware:
    """Declarative description of a middleware, to be instantiated later.

    Holds a middleware class together with the extra positional and keyword
    arguments it should be constructed with. Per ``_MiddlewareClass``, the
    class is expected to be called as ``cls(app, *args, **kwargs)``; the
    ``ParamSpec`` ``P`` lets type checkers validate ``args``/``kwargs``
    against that constructor signature.
    """

    def __init__(
        self,
        cls: type[_MiddlewareClass[P]],
        *args: P.args,
        **kwargs: P.kwargs,
    ) -> None:
        """Store the middleware class and its constructor arguments.

        Args:
            cls: The middleware class to instantiate. Only the type hint
                requires a class; at runtime any compatible callable is stored.
            *args: Extra positional arguments intended for ``cls`` after ``app``.
            **kwargs: Keyword arguments intended for ``cls``.
        """
        self.cls = cls
        self.args = args
        self.kwargs = kwargs

    def __iter__(self) -> Iterator[Any]:
        """Iterate over ``(cls, args, kwargs)``.

        This allows tuple-style unpacking: ``cls, args, kwargs = middleware``.
        """
        as_tuple = (self.cls, self.args, self.kwargs)
        return iter(as_tuple)

    def __repr__(self) -> str:
        """Return a constructor-like string, e.g. ``Middleware(Foo, 1, x='a')``.

        ``cls`` may be a callable without ``__name__`` (for example a
        ``functools.partial``). In that case its ``repr`` is used instead of
        raising ``AttributeError``.
        """
        class_name = self.__class__.__name__
        cls_name = getattr(self.cls, "__name__", None)
        if cls_name is None:
            cls_name = repr(self.cls)
        args_strings = [f"{value!r}" for value in self.args]
        option_strings = [f"{key}={value!r}" for key, value in self.kwargs.items()]
        args_repr = ", ".join([cls_name, *args_strings, *option_strings])
        return f"{class_name}({args_repr})"