1from __future__ import annotations
2
3import typing
4
5from starlette._utils import is_async_callable
6from starlette.concurrency import run_in_threadpool
7from starlette.exceptions import HTTPException
8from starlette.requests import Request
9from starlette.types import (
10 ASGIApp,
11 ExceptionHandler,
12 HTTPExceptionHandler,
13 Message,
14 Receive,
15 Scope,
16 Send,
17 WebSocketExceptionHandler,
18)
19from starlette.websockets import WebSocket
20
# Exception handlers keyed by exception class (typed loosely as ``Any``); a
# raised exception is matched by walking its MRO in _lookup_exception_handler.
ExceptionHandlers = typing.Dict[typing.Any, ExceptionHandler]
# Exception handlers keyed by HTTP status code; for an HTTPException these are
# consulted before the class-based handlers.
StatusHandlers = typing.Dict[int, ExceptionHandler]
23
24
def _lookup_exception_handler(
    exc_handlers: ExceptionHandlers, exc: Exception
) -> ExceptionHandler | None:
    """Return the handler registered for the most specific class of ``exc``.

    The exception's own type is checked first, then each of its base classes
    in method-resolution order, so the closest registration wins. The first
    class present as a key decides the result (even if its value is ``None``).
    Returns ``None`` when no class in the hierarchy is registered.
    """
    candidates = (
        exc_handlers[klass]
        for klass in type(exc).__mro__
        if klass in exc_handlers
    )
    return next(candidates, None)
32
33
def wrap_app_handling_exceptions(app: ASGIApp, conn: Request | WebSocket) -> ASGIApp:
    """Wrap ``app`` so exceptions it raises are routed to registered handlers.

    Handlers are read from ``conn.scope["starlette.exception_handlers"]``, an
    ``(exception_handlers, status_handlers)`` pair. When that key is missing,
    no handlers are registered and every exception propagates unchanged.

    For an ``HTTPException``, a handler registered for its ``status_code``
    takes precedence; otherwise the handler for the closest class in the
    exception's MRO is used. Handlers that are not async callables are run
    via ``run_in_threadpool``.

    Args:
        app: The ASGI application to wrap.
        conn: The connection whose scope carries the handler registry; it is
            passed to the handler when an exception is handled.

    Returns:
        An ASGI application. When it runs, it re-raises the original exception
        if no handler matches, and raises ``RuntimeError`` if a handler matches
        but an ``http.response.start`` message has already been sent.
    """
    exception_handlers: ExceptionHandlers
    status_handlers: StatusHandlers
    try:
        exception_handlers, status_handlers = conn.scope["starlette.exception_handlers"]
    except KeyError:
        exception_handlers, status_handlers = {}, {}

    async def wrapped_app(scope: Scope, receive: Receive, send: Send) -> None:
        response_started = False

        async def sender(message: Message) -> None:
            nonlocal response_started

            # Once the response start has been sent, a handler can no longer
            # substitute its own response.
            if message["type"] == "http.response.start":
                response_started = True
            await send(message)

        try:
            await app(scope, receive, sender)
        except Exception as exc:
            handler = None

            if isinstance(exc, HTTPException):
                handler = status_handlers.get(exc.status_code)

            if handler is None:
                handler = _lookup_exception_handler(exception_handlers, exc)

            if handler is None:
                # Bare ``raise`` re-raises with the original traceback intact,
                # without prepending an extra frame as ``raise exc`` would.
                raise

            if response_started:
                msg = "Caught handled exception, but response already started."
                raise RuntimeError(msg) from exc

            # Narrow types into fresh locals instead of rebinding the enclosing
            # ``conn`` via ``nonlocal``; that closure variable is shared by
            # every invocation of ``wrapped_app``.
            if scope["type"] == "http":
                http_handler = typing.cast(HTTPExceptionHandler, handler)
                request = typing.cast(Request, conn)
                if is_async_callable(http_handler):
                    response = await http_handler(request, exc)
                else:
                    response = await run_in_threadpool(http_handler, request, exc)
                await response(scope, receive, sender)
            elif scope["type"] == "websocket":
                ws_handler = typing.cast(WebSocketExceptionHandler, handler)
                websocket = typing.cast(WebSocket, conn)
                if is_async_callable(ws_handler):
                    await ws_handler(websocket, exc)
                else:
                    await run_in_threadpool(ws_handler, websocket, exc)
            # NOTE(review): for any other scope type a matched handler is not
            # invoked and the exception is suppressed; confirm this is intended.

    return wrapped_app