1from __future__ import annotations
2
3import sys
4from collections.abc import Iterator
5from typing import Any, Protocol
6
7if sys.version_info >= (3, 10): # pragma: no cover
8 from typing import ParamSpec
9else: # pragma: no cover
10 from typing_extensions import ParamSpec
11
12from starlette.types import ASGIApp
13
# Parameter specification shared by `_MiddlewareFactory` and `Middleware.__init__`:
# it ties the extra ``*args``/``**kwargs`` given to `Middleware` to the parameters
# the factory declares after its leading ``app`` argument.
P = ParamSpec("P")
15
16
class _MiddlewareFactory(Protocol[P]):
    """Structural type for a callable that wraps an ASGI app in middleware.

    The callable takes the app to wrap as a positional-only first argument,
    followed by the parameters captured by ``P``, and returns an ASGI app.
    """

    def __call__(  # pragma: no cover
        self,
        app: ASGIApp,
        /,
        *args: P.args,
        **kwargs: P.kwargs,
    ) -> ASGIApp:
        """Return an ASGI app built around ``app`` using the given arguments."""
        ...
19
20
class Middleware:
    """Hold a middleware factory together with the arguments meant for it.

    The factory is only stored here, never called. Iterating an instance
    yields ``cls``, ``args`` and ``kwargs`` so it can be unpacked like a
    3-tuple.
    """

    def __init__(
        self,
        cls: _MiddlewareFactory[P],
        *args: P.args,
        **kwargs: P.kwargs,
    ) -> None:
        """Store the factory and its arguments without invoking the factory.

        Args:
            cls: Callable annotated as taking the app to wrap first, followed
                by ``*args`` and ``**kwargs``.
            *args: Extra positional arguments kept for the factory.
            **kwargs: Extra keyword arguments kept for the factory.
        """
        self.cls = cls
        self.args = args
        self.kwargs = kwargs

    def __iter__(self) -> Iterator[Any]:
        """Iterate over ``(cls, args, kwargs)``, in that order."""
        return iter((self.cls, self.args, self.kwargs))

    def __repr__(self) -> str:
        """Return ``Middleware(<factory>, <args...>, <key>=<value>...)``.

        The factory is shown by its ``__name__`` when it has a non-empty one.
        """
        # A callable without ``__name__`` (e.g. an instance defining
        # ``__call__``) used to contribute an empty string, producing
        # ``Middleware(, ...)``; fall back to its repr instead.
        factory_label = getattr(self.cls, "__name__", None) or repr(self.cls)
        positional = [repr(value) for value in self.args]
        keyword = [f"{key}={value!r}" for key, value in self.kwargs.items()]
        joined = ", ".join([factory_label, *positional, *keyword])
        return f"{type(self).__name__}({joined})"