Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/starlette/middleware/exceptions.py: 23%
66 statements
Generated by coverage.py v7.2.2 at 2023-03-26 06:12 +0000
1import typing
3from starlette._utils import is_async_callable
4from starlette.concurrency import run_in_threadpool
5from starlette.exceptions import HTTPException, WebSocketException
6from starlette.requests import Request
7from starlette.responses import PlainTextResponse, Response
8from starlette.types import ASGIApp, Message, Receive, Scope, Send
9from starlette.websockets import WebSocket
class ExceptionMiddleware:
    """ASGI middleware that turns exceptions raised by the wrapped app into responses.

    Handlers are registered either per exception class or, for
    ``HTTPException`` only, per status code.
    - Class handlers are resolved by walking the raised exception's MRO, so the
      most specific registered class wins.
    - A status-code handler matching an ``HTTPException`` takes precedence over
      any class handler.
    - Default handlers for ``HTTPException`` and ``WebSocketException`` are
      installed at construction time. Entries in ``handlers`` (or later
      ``add_exception_handler`` calls) override them.

    Scopes whose type is neither ``"http"`` nor ``"websocket"`` are passed
    through untouched.
    """

    def __init__(
        self,
        app: ASGIApp,
        handlers: typing.Optional[
            typing.Mapping[typing.Any, typing.Callable[[Request, Exception], Response]]
        ] = None,
        debug: bool = False,
    ) -> None:
        """Wrap ``app`` and register the default and user-supplied handlers.

        Args:
            app: The downstream ASGI application.
            handlers: Optional mapping of exception class or integer status code
                to handler callable. Each entry is registered via
                ``add_exception_handler`` after the defaults, so it can replace
                them.
            debug: Stored on the instance; not otherwise used by this class.
        """
        self.app = app
        self.debug = debug  # TODO: We ought to handle 404 cases if debug is set.
        # Status-code handlers, consulted only for HTTPException instances.
        self._status_handlers: typing.Dict[int, typing.Callable] = {}
        # Exception-class handlers, resolved via the raised exception's MRO.
        self._exception_handlers: typing.Dict[
            typing.Type[Exception], typing.Callable
        ] = {
            HTTPException: self.http_exception,
            WebSocketException: self.websocket_exception,
        }
        if handlers is not None:
            for key, value in handlers.items():
                self.add_exception_handler(key, value)

    def add_exception_handler(
        self,
        exc_class_or_status_code: typing.Union[int, typing.Type[Exception]],
        handler: typing.Callable[[Request, Exception], Response],
    ) -> None:
        """Register ``handler`` for a status code or an exception class.

        An ``int`` key registers a status-code handler. Any other key must be
        an ``Exception`` subclass and registers a class handler. A later
        registration for the same key replaces the earlier one.

        Raises:
            AssertionError: if a non-int key is not an ``Exception`` subclass.
                This check is an ``assert`` and is skipped under ``python -O``.
        """
        if isinstance(exc_class_or_status_code, int):
            self._status_handlers[exc_class_or_status_code] = handler
        else:
            assert issubclass(exc_class_or_status_code, Exception)
            self._exception_handlers[exc_class_or_status_code] = handler

    def _lookup_exception_handler(
        self, exc: Exception
    ) -> typing.Optional[typing.Callable]:
        """Return the handler for the most specific registered class in ``exc``'s MRO, or None."""
        for cls in type(exc).__mro__:
            if cls in self._exception_handlers:
                return self._exception_handlers[cls]
        return None

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Run the wrapped app and dispatch any exception it raises to a handler.

        - If no handler matches, the exception propagates unchanged.
        - If a handler matches but an ``http.response.start`` message has
          already been sent, a ``RuntimeError`` is raised instead, chained
          from the original exception.
        - Otherwise the handler runs. Synchronous handlers are executed via
          ``run_in_threadpool``.
        """
        if scope["type"] not in ("http", "websocket"):
            await self.app(scope, receive, send)
            return

        response_started = False

        async def sender(message: Message) -> None:
            # Track whether the app already began an HTTP response, since a
            # handler's response can no longer be sent after that point.
            nonlocal response_started

            if message["type"] == "http.response.start":
                response_started = True
            await send(message)

        try:
            await self.app(scope, receive, sender)
        except Exception as exc:
            handler = None

            # Status-code handlers take precedence for HTTPException.
            if isinstance(exc, HTTPException):
                handler = self._status_handlers.get(exc.status_code)

            if handler is None:
                handler = self._lookup_exception_handler(exc)

            if handler is None:
                # Bare ``raise`` re-raises the active exception with its
                # original traceback, without appending this frame again.
                raise

            if response_started:
                msg = "Caught handled exception, but response already started."
                raise RuntimeError(msg) from exc

            if scope["type"] == "http":
                request = Request(scope, receive=receive)
                if is_async_callable(handler):
                    response = await handler(request, exc)
                else:
                    response = await run_in_threadpool(handler, request, exc)
                await response(scope, receive, sender)
            elif scope["type"] == "websocket":
                websocket = WebSocket(scope, receive=receive, send=send)
                if is_async_callable(handler):
                    await handler(websocket, exc)
                else:
                    await run_in_threadpool(handler, websocket, exc)

    def http_exception(self, request: Request, exc: HTTPException) -> Response:
        """Default ``HTTPException`` handler.

        Returns:
            - For 204 and 304, a bodyless ``Response`` (those statuses carry
              no body).
            - Otherwise, a ``PlainTextResponse`` containing ``exc.detail``.
            In both cases the response uses ``exc.status_code`` and
            ``exc.headers``.
        """
        if exc.status_code in {204, 304}:
            return Response(status_code=exc.status_code, headers=exc.headers)
        return PlainTextResponse(
            exc.detail, status_code=exc.status_code, headers=exc.headers
        )

    async def websocket_exception(
        self, websocket: WebSocket, exc: WebSocketException
    ) -> None:
        """Default ``WebSocketException`` handler: close the socket with ``exc.code`` and ``exc.reason``."""
        await websocket.close(code=exc.code, reason=exc.reason)