1from __future__ import annotations
2
3import typing
4
5from starlette._exception_handler import (
6 ExceptionHandlers,
7 StatusHandlers,
8 wrap_app_handling_exceptions,
9)
10from starlette.exceptions import HTTPException, WebSocketException
11from starlette.requests import Request
12from starlette.responses import PlainTextResponse, Response
13from starlette.types import ASGIApp, Receive, Scope, Send
14from starlette.websockets import WebSocket
15
16
class ExceptionMiddleware:
    """ASGI middleware that dispatches exceptions from the wrapped app to handlers.

    Handlers are registered either by exception class or by integer HTTP
    status code. On every ``"http"`` / ``"websocket"`` connection, both
    registries are published on the ASGI scope under
    ``"starlette.exception_handlers"``, and the wrapped app is run through
    ``wrap_app_handling_exceptions``.

    Default handlers are installed for ``HTTPException`` (plain-text response)
    and ``WebSocketException`` (close the websocket with the exception's code
    and reason).
    """

    def __init__(
        self,
        app: ASGIApp,
        handlers: typing.Mapping[
            typing.Any, typing.Callable[[Request, Exception], Response]
        ]
        | None = None,
        debug: bool = False,
    ) -> None:
        """Wrap *app* and register the default and any caller-supplied handlers.

        Args:
            app: The downstream ASGI application.
            handlers: Optional mapping from an exception class or an ``int``
                status code to a handler. Each entry is registered through
                :meth:`add_exception_handler` after the defaults, so an entry
                for ``HTTPException`` or ``WebSocketException`` replaces the
                built-in handler.
            debug: Stored as ``self.debug``; this class does not otherwise
                read it.
        """
        self.app = app
        self.debug = debug  # TODO: We ought to handle 404 cases if debug is set.
        self._status_handlers: StatusHandlers = {}
        self._exception_handlers: ExceptionHandlers = {
            HTTPException: self.http_exception,
            WebSocketException: self.websocket_exception,
        }
        if handlers is not None:
            for key, value in handlers.items():
                self.add_exception_handler(key, value)

    def add_exception_handler(
        self,
        exc_class_or_status_code: int | type[Exception],
        handler: typing.Callable[[Request, Exception], Response],
    ) -> None:
        """Register *handler* for an HTTP status code or an exception class.

        Registering the same key again replaces the previous handler.

        Args:
            exc_class_or_status_code: An ``int`` status code, or a subclass
                of ``Exception``.
            handler: Callable invoked with the connection and the exception.

        Raises:
            TypeError: If *exc_class_or_status_code* is neither an ``int``
                nor a subclass of ``Exception``.
        """
        if isinstance(exc_class_or_status_code, int):
            self._status_handlers[exc_class_or_status_code] = handler
            return
        # Validate explicitly instead of with ``assert``: asserts are removed
        # under ``python -O``, which would let an invalid key be registered
        # silently rather than failing at configuration time.
        if not (
            isinstance(exc_class_or_status_code, type)
            and issubclass(exc_class_or_status_code, Exception)
        ):
            raise TypeError(
                "exc_class_or_status_code must be an int status code or an "
                f"Exception subclass, got {exc_class_or_status_code!r}"
            )
        self._exception_handlers[exc_class_or_status_code] = handler

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Run the wrapped app with exception handling for HTTP/websocket scopes.

        Any other scope type is forwarded to the wrapped app untouched.
        """
        if scope["type"] not in ("http", "websocket"):
            await self.app(scope, receive, send)
            return

        # Expose the registries on the scope so code further down the stack
        # can look the handlers up.
        scope["starlette.exception_handlers"] = (
            self._exception_handlers,
            self._status_handlers,
        )

        conn: Request | WebSocket
        if scope["type"] == "http":
            conn = Request(scope, receive, send)
        else:
            conn = WebSocket(scope, receive, send)

        await wrap_app_handling_exceptions(self.app, conn)(scope, receive, send)

    def http_exception(self, request: Request, exc: Exception) -> Response:
        """Default handler for ``HTTPException``.

        Returns a bodyless ``Response`` for status 204 and 304, because HTTP
        does not allow content on those responses. Every other status gets a
        ``PlainTextResponse`` containing ``exc.detail``. The exception's
        headers are forwarded in both cases.
        """
        # Narrows ``exc`` for type checkers; this handler is registered for
        # ``HTTPException`` in ``__init__``.
        assert isinstance(exc, HTTPException)
        if exc.status_code in {204, 304}:
            return Response(status_code=exc.status_code, headers=exc.headers)
        return PlainTextResponse(
            exc.detail, status_code=exc.status_code, headers=exc.headers
        )

    async def websocket_exception(self, websocket: WebSocket, exc: Exception) -> None:
        """Default handler for ``WebSocketException``.

        Closes *websocket* using the exception's ``code`` and ``reason``.
        """
        # Narrows ``exc`` for type checkers; this handler is registered for
        # ``WebSocketException`` in ``__init__``.
        assert isinstance(exc, WebSocketException)
        await websocket.close(code=exc.code, reason=exc.reason)  # pragma: no cover