1from __future__ import annotations
2
3from collections.abc import Mapping
4from typing import Any, Callable
5
6from starlette._exception_handler import (
7 ExceptionHandlers,
8 StatusHandlers,
9 wrap_app_handling_exceptions,
10)
11from starlette.exceptions import HTTPException, WebSocketException
12from starlette.requests import Request
13from starlette.responses import PlainTextResponse, Response
14from starlette.types import ASGIApp, Receive, Scope, Send
15from starlette.websockets import WebSocket
16
17
class ExceptionMiddleware:
    """ASGI middleware that routes exceptions from the wrapped app to handlers.

    Handlers are registered either by ``int`` status code or by exception
    class. ``HTTPException`` and ``WebSocketException`` have built-in default
    handlers (:meth:`http_exception` and :meth:`websocket_exception`). Entries
    passed via ``handlers`` or :meth:`add_exception_handler` can replace them.
    """

    def __init__(
        self,
        app: ASGIApp,
        handlers: Mapping[Any, Callable[[Request, Exception], Response]] | None = None,
        debug: bool = False,
    ) -> None:
        """Create the middleware.

        Args:
            app: The ASGI application to wrap.
            handlers: Optional initial handlers. Each item is registered
                through :meth:`add_exception_handler` after the built-in
                defaults, so matching keys here override those defaults.
            debug: Stored on the instance; not otherwise used by this class.

        Raises:
            TypeError: If a key in ``handlers`` is neither an ``int`` status
                code nor an ``Exception`` subclass.
        """
        self.app = app
        self.debug = debug  # TODO: We ought to handle 404 cases if debug is set.
        self._status_handlers: StatusHandlers = {}
        # Defaults go in first so user-supplied ``handlers`` can replace them.
        self._exception_handlers: ExceptionHandlers = {
            HTTPException: self.http_exception,
            WebSocketException: self.websocket_exception,
        }
        if handlers is not None:  # pragma: no branch
            for key, value in handlers.items():
                self.add_exception_handler(key, value)

    def add_exception_handler(
        self,
        exc_class_or_status_code: int | type[Exception],
        handler: Callable[[Request, Exception], Response],
    ) -> None:
        """Register ``handler`` for a status code or an exception class.

        Args:
            exc_class_or_status_code: Where the handler is stored depends on
                the key type:
                - an ``int`` status code goes into the status-handler registry;
                - an ``Exception`` subclass goes into the exception-handler
                  registry.
                Registering an existing key replaces the previous handler.
            handler: Callable invoked with the connection and the exception.

        Raises:
            TypeError: If the key is neither an ``int`` nor an ``Exception``
                subclass.
        """
        if isinstance(exc_class_or_status_code, int):
            self._status_handlers[exc_class_or_status_code] = handler
            return
        # Use an explicit check rather than ``assert``: ``python -O`` strips
        # asserts, which would let bad keys be stored silently.
        if not (
            isinstance(exc_class_or_status_code, type)
            and issubclass(exc_class_or_status_code, Exception)
        ):
            raise TypeError(
                "exc_class_or_status_code must be an int status code or an "
                f"Exception subclass, got {exc_class_or_status_code!r}"
            )
        self._exception_handlers[exc_class_or_status_code] = handler

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """ASGI entry point.

        Scopes whose type is not ``"http"`` or ``"websocket"`` are forwarded
        to the wrapped app unchanged. For the other two types this method:
        - stores the handler registries on the scope under
          ``"starlette.exception_handlers"`` as an
          ``(exception_handlers, status_handlers)`` tuple;
        - runs the app through ``wrap_app_handling_exceptions``, together
          with a ``Request`` (http) or ``WebSocket`` (websocket) built from
          the scope.
        """
        if scope["type"] not in ("http", "websocket"):
            await self.app(scope, receive, send)
            return

        # Publish the registries on the (mutated) scope so code further down
        # the stack can reach them — NOTE(review): presumably consumed by
        # routing / wrap_app_handling_exceptions; confirm.
        scope["starlette.exception_handlers"] = (
            self._exception_handlers,
            self._status_handlers,
        )

        conn: Request | WebSocket
        if scope["type"] == "http":
            conn = Request(scope, receive, send)
        else:
            conn = WebSocket(scope, receive, send)

        await wrap_app_handling_exceptions(self.app, conn)(scope, receive, send)

    async def http_exception(self, request: Request, exc: Exception) -> Response:
        """Default handler for ``HTTPException``.

        In both cases the status code and headers come from the exception:
        - for 204 or 304, returns a ``Response`` with no body;
        - otherwise, returns a ``PlainTextResponse`` whose body is
          ``exc.detail``.
        """
        # Registered for HTTPException only; the assert narrows the type.
        assert isinstance(exc, HTTPException)
        if exc.status_code in {204, 304}:
            # 204 No Content / 304 Not Modified must not carry a message body.
            return Response(status_code=exc.status_code, headers=exc.headers)
        return PlainTextResponse(exc.detail, status_code=exc.status_code, headers=exc.headers)

    async def websocket_exception(self, websocket: WebSocket, exc: Exception) -> None:
        """Default handler for ``WebSocketException``.

        Closes the websocket using the exception's ``code`` and ``reason``.
        """
        # Registered for WebSocketException only; the assert narrows the type.
        assert isinstance(exc, WebSocketException)
        await websocket.close(code=exc.code, reason=exc.reason)  # pragma: no cover