1from __future__ import annotations
2
3import typing
4
5from starlette._utils import is_async_callable
6from starlette.concurrency import run_in_threadpool
7from starlette.exceptions import HTTPException
8from starlette.requests import Request
9from starlette.types import (
10 ASGIApp,
11 ExceptionHandler,
12 HTTPExceptionHandler,
13 Message,
14 Receive,
15 Scope,
16 Send,
17 WebSocketExceptionHandler,
18)
19from starlette.websockets import WebSocket
20
# Exception-handler table. `_lookup_exception_handler` matches its keys against
# the classes in a raised exception's MRO, so keys are expected to be exception
# classes (typed as Any here).
ExceptionHandlers = typing.Dict[typing.Any, ExceptionHandler]
# Status-handler table, keyed by HTTP status code; consulted with
# `HTTPException.status_code` before the class-based table.
StatusHandlers = typing.Dict[int, ExceptionHandler]
23
24
def _lookup_exception_handler(exc_handlers: ExceptionHandlers, exc: Exception) -> ExceptionHandler | None:
    """Return the handler registered for the most specific class of *exc*.

    The method resolution order of ``type(exc)`` is walked from the concrete
    type outwards; the first class that has an entry in *exc_handlers* wins.
    ``None`` is returned when no class in the hierarchy is registered.
    """
    return next(
        (exc_handlers[klass] for klass in type(exc).__mro__ if klass in exc_handlers),
        None,
    )
30
31
def wrap_app_handling_exceptions(app: ASGIApp, conn: Request | WebSocket) -> ASGIApp:
    """Wrap *app* so that exceptions it raises are routed to registered handlers.

    The handler tables are read from ``conn.scope["starlette.exception_handlers"]``,
    which is unpacked as ``(exception_handlers, status_handlers)``. When that key
    is absent both tables are empty, so every exception propagates unchanged.

    For an exception raised by *app*:

    * an ``HTTPException`` is first looked up in ``status_handlers`` by its
      ``status_code``;
    * if nothing matched, ``exception_handlers`` is searched along the
      exception's MRO;
    * with no handler, the original exception is re-raised;
    * if ``http.response.start`` was already sent, a ``RuntimeError`` chained
      from the exception is raised instead of invoking the handler;
    * for ``http`` scopes the handler's return value is invoked as an ASGI app
      (i.e. sent as the response); for ``websocket`` scopes the handler is
      called and its return value is ignored;
    * for any other scope type the original exception is re-raised.

    Async handlers are awaited; sync handlers run via ``run_in_threadpool``.

    Args:
        app: The ASGI application to wrap.
        conn: The request or websocket for this connection; it is passed to
            the handler as its first argument.

    Returns:
        The wrapping ASGI application.

    Raises:
        From the returned app: the original exception when it is not handled,
        or ``RuntimeError`` when a handler matches after the response started.
    """
    exception_handlers: ExceptionHandlers
    status_handlers: StatusHandlers
    try:
        exception_handlers, status_handlers = conn.scope["starlette.exception_handlers"]
    except KeyError:
        exception_handlers, status_handlers = {}, {}

    async def wrapped_app(scope: Scope, receive: Receive, send: Send) -> None:
        response_started = False

        async def sender(message: Message) -> None:
            nonlocal response_started

            # Record that the response head went out: after that, a handler
            # may no longer produce a replacement response (see below).
            if message["type"] == "http.response.start":
                response_started = True
            await send(message)

        try:
            await app(scope, receive, sender)
        except Exception as exc:
            handler = None

            if isinstance(exc, HTTPException):
                handler = status_handlers.get(exc.status_code)

            if handler is None:
                handler = _lookup_exception_handler(exception_handlers, exc)

            if handler is None:
                # Bare ``raise`` re-raises without adding this frame's line.
                raise

            if response_started:
                msg = "Caught handled exception, but response already started."
                raise RuntimeError(msg) from exc

            if scope["type"] == "http":
                # Narrow into fresh locals rather than rebinding the enclosing
                # ``conn`` through ``nonlocal``.
                http_handler = typing.cast(HTTPExceptionHandler, handler)
                request = typing.cast(Request, conn)
                if is_async_callable(http_handler):
                    response = await http_handler(request, exc)
                else:
                    response = await run_in_threadpool(http_handler, request, exc)
                await response(scope, receive, sender)
            elif scope["type"] == "websocket":
                websocket_handler = typing.cast(WebSocketExceptionHandler, handler)
                websocket = typing.cast(WebSocket, conn)
                if is_async_callable(websocket_handler):
                    await websocket_handler(websocket, exc)
                else:
                    await run_in_threadpool(websocket_handler, websocket, exc)
            else:
                # No known way to invoke a handler for this scope type; propagate
                # the exception instead of silently swallowing it.
                raise

    return wrapped_app