1import typing as t
2
3from starlette.routing import Router
4from starlette.types import ASGIApp, Receive, Scope, Send
5
# Type of the ``lifespan`` handler accepted by LifespanMiddleware: a callable
# taking a single argument and returning an async context manager that wraps
# the application's startup/shutdown.
Lifespan = t.Callable[
    [t.Any],
    t.AsyncContextManager,
]
7
8
class LifespanMiddleware:
    """
    ASGI middleware that serves Starlette lifespan handlers
    (https://www.starlette.io/lifespan/) in front of another ASGI app.

    If a ``lifespan`` handler is supplied, ``"lifespan"`` scopes are answered
    by an internal Starlette ``Router`` and are NOT forwarded downstream.
    Every other scope type, and every scope at all when ``lifespan`` is falsy
    (e.g. ``None``), is passed unchanged to ``next_app``.
    """

    def __init__(self, next_app: ASGIApp, *, lifespan: t.Optional[Lifespan]) -> None:
        """
        :param next_app: the downstream ASGI application.
        :param lifespan: callable returning an async context manager for the
            app's startup/shutdown, or ``None`` to leave lifespan events to
            ``next_app``.
        """
        self._lifespan = lifespan
        self.next_app = next_app
        # The Router has no routes; it is used purely for its implementation
        # of the lifespan protocol. It is built even when ``lifespan`` is None.
        self.router = Router(lifespan=lifespan)

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Dispatch one ASGI scope to the lifespan router or to ``next_app``."""
        # Without a registered handler, lifespan events must still reach the
        # downstream app so it can handle them itself.
        handle_here = scope["type"] == "lifespan" and bool(self._lifespan)
        target = self.router if handle_here else self.next_app
        await target(scope, receive, send)