1from __future__ import annotations
2
3from typing import Any
4
5from starlette._utils import is_async_callable
6from starlette.concurrency import run_in_threadpool
7from starlette.exceptions import HTTPException
8from starlette.requests import Request
9from starlette.types import ASGIApp, ExceptionHandler, Message, Receive, Scope, Send
10from starlette.websockets import WebSocket
11
# Registry of handlers looked up by HTTP status code.
StatusHandlers = dict[int, ExceptionHandler]
# Registry of handlers looked up by exception class (keys typed loosely as Any).
ExceptionHandlers = dict[Any, ExceptionHandler]
14
15
def _lookup_exception_handler(exc_handlers: ExceptionHandlers, exc: Exception) -> ExceptionHandler | None:
    """Return the handler registered for the most specific class of ``exc``.

    The classes of ``type(exc).__mro__`` are checked in order, so a handler
    registered for the exception's own class wins over one registered for any
    of its base classes.

    Args:
        exc_handlers: Mapping from exception class to handler.
        exc: The exception that needs a handler.

    Returns:
        The first matching handler in MRO order, or ``None`` when no class in
        the MRO is registered.
    """
    candidates = (exc_handlers[klass] for klass in type(exc).__mro__ if klass in exc_handlers)
    return next(candidates, None)
21
22
def wrap_app_handling_exceptions(app: ASGIApp, conn: Request | WebSocket) -> ASGIApp:
    """Wrap ``app`` so registered exception handlers can turn errors into responses.

    The handler registries are read once from
    ``conn.scope["starlette.exception_handlers"]``, which is unpacked as an
    ``(exception_handlers, status_handlers)`` pair. If that key is absent, both
    registries are empty and every exception raised by ``app`` propagates
    unchanged.

    Args:
        app: The ASGI application to run.
        conn: The request or websocket object passed to whichever handler is
            invoked.

    Returns:
        An ASGI callable that runs ``app``. If ``app`` raises, the callable
        dispatches the exception to a matching handler.

    Raises:
        RuntimeError: Raised by the returned callable when a handler matches
            but an ``http.response.start`` message has already been sent.
        Exception: The original exception, re-raised by the returned callable
            when no handler matches.
    """
    exception_handlers: ExceptionHandlers
    status_handlers: StatusHandlers
    try:
        exception_handlers, status_handlers = conn.scope["starlette.exception_handlers"]
    except KeyError:
        exception_handlers, status_handlers = {}, {}

    async def wrapped_app(scope: Scope, receive: Receive, send: Send) -> None:
        # Becomes True once the HTTP response head has been forwarded. After
        # that, a handler can no longer send a fresh response.
        response_started = False

        async def sender(message: Message) -> None:
            nonlocal response_started

            if message["type"] == "http.response.start":
                response_started = True
            await send(message)

        try:
            await app(scope, receive, sender)
        except Exception as exc:
            handler = None

            # For HTTPException, a handler registered for its status code is
            # tried first ...
            if isinstance(exc, HTTPException):
                handler = status_handlers.get(exc.status_code)

            # ... otherwise fall back to the most specific registered class in
            # the exception's MRO.
            if handler is None:
                handler = _lookup_exception_handler(exception_handlers, exc)

            if handler is None:
                # Bare ``raise`` re-raises the active exception as-is, without
                # appending a redundant traceback entry for this line.
                raise

            if response_started:
                raise RuntimeError("Caught handled exception, but response already started.") from exc

            # Coroutine handlers are awaited directly; other callables are
            # dispatched through ``run_in_threadpool``.
            if is_async_callable(handler):
                response = await handler(conn, exc)
            else:
                response = await run_in_threadpool(handler, conn, exc)  # type: ignore
            # A non-None result is invoked as an ASGI app. Its messages go
            # through ``sender``, so response-start tracking still applies.
            if response is not None:
                await response(scope, receive, sender)

    return wrapped_app