Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/starlette/middleware/exceptions.py: 23%

66 statements  

« prev     ^ index     » next       coverage.py v7.2.2, created at 2023-03-26 06:12 +0000

1import typing 

2 

3from starlette._utils import is_async_callable 

4from starlette.concurrency import run_in_threadpool 

5from starlette.exceptions import HTTPException, WebSocketException 

6from starlette.requests import Request 

7from starlette.responses import PlainTextResponse, Response 

8from starlette.types import ASGIApp, Message, Receive, Scope, Send 

9from starlette.websockets import WebSocket 

10 

11 

class ExceptionMiddleware:
    """ASGI middleware that routes exceptions raised by the wrapped app to handlers.

    Handlers can be registered by HTTP status code or by exception class:

    * Status-code handlers are consulted only when the raised exception is an
      ``HTTPException``, keyed on its ``status_code``.
    * Exception-class handlers are matched against the raised exception's MRO,
      most specific class first.

    ``HTTPException`` and ``WebSocketException`` have built-in default
    handlers (:meth:`http_exception` and :meth:`websocket_exception`).
    """

    def __init__(
        self,
        app: ASGIApp,
        handlers: typing.Optional[
            typing.Mapping[typing.Any, typing.Callable[[Request, Exception], Response]]
        ] = None,
        debug: bool = False,
    ) -> None:
        """Wrap *app*, optionally pre-registering exception handlers.

        Args:
            app: The downstream ASGI application.
            handlers: Optional mapping whose keys are either ``int`` status
                codes or exception classes. Each entry is passed to
                :meth:`add_exception_handler`, so an entry may replace one of
                the default handlers.
            debug: Stored on the instance; not otherwise read by this class.
        """
        self.app = app
        self.debug = debug  # TODO: We ought to handle 404 cases if debug is set.
        self._status_handlers: typing.Dict[int, typing.Callable] = {}
        self._exception_handlers: typing.Dict[
            typing.Type[Exception], typing.Callable
        ] = {
            HTTPException: self.http_exception,
            WebSocketException: self.websocket_exception,
        }
        if handlers is not None:
            for key, value in handlers.items():
                self.add_exception_handler(key, value)

    def add_exception_handler(
        self,
        exc_class_or_status_code: typing.Union[int, typing.Type[Exception]],
        handler: typing.Callable[[Request, Exception], Response],
    ) -> None:
        """Register *handler* for a status code or an exception class.

        Args:
            exc_class_or_status_code: An ``int`` registers a status-code
                handler. Anything else must be a subclass of ``Exception``.
            handler: Sync or async callable. It receives the request (or the
                websocket) and the exception.

        Raises:
            AssertionError: If a non-``int`` key is not an ``Exception``
                subclass (only when assertions are enabled).
        """
        if isinstance(exc_class_or_status_code, int):
            self._status_handlers[exc_class_or_status_code] = handler
        else:
            # NOTE: validated with ``assert``, so this check is skipped under
            # ``python -O``; kept as-is so callers still see AssertionError.
            assert issubclass(exc_class_or_status_code, Exception)
            self._exception_handlers[exc_class_or_status_code] = handler

    def _lookup_exception_handler(
        self, exc: Exception
    ) -> typing.Optional[typing.Callable]:
        """Return the handler for the most specific class in ``type(exc).__mro__``.

        Returns ``None`` when no class in the MRO has a registered handler.
        """
        for cls in type(exc).__mro__:
            if cls in self._exception_handlers:
                return self._exception_handlers[cls]
        return None

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Run the wrapped app and dispatch any exception it raises.

        Scopes other than ``http`` and ``websocket`` are passed straight through
        with no exception handling.

        Raises:
            Exception: The original exception, re-raised unchanged, when no
                handler matches it.
            RuntimeError: When a matching handler exists but an
                ``http.response.start`` message has already been sent.
        """
        if scope["type"] not in ("http", "websocket"):
            await self.app(scope, receive, send)
            return

        response_started = False

        async def sender(message: Message) -> None:
            # Wraps ``send`` only to record whether the HTTP response has begun.
            # After that point, a handler can no longer send its own response.
            nonlocal response_started

            if message["type"] == "http.response.start":
                response_started = True
            await send(message)

        try:
            await self.app(scope, receive, sender)
        except Exception as exc:
            handler = None

            # Status-code handlers take precedence for HTTPException instances.
            if isinstance(exc, HTTPException):
                handler = self._status_handlers.get(exc.status_code)

            if handler is None:
                handler = self._lookup_exception_handler(exc)

            if handler is None:
                # Bare ``raise`` re-raises the active exception with its
                # original traceback intact.
                raise

            if response_started:
                msg = "Caught handled exception, but response already started."
                raise RuntimeError(msg) from exc

            if scope["type"] == "http":
                request = Request(scope, receive=receive)
                # Sync handlers are run in a threadpool so they do not block
                # the event loop.
                if is_async_callable(handler):
                    response = await handler(request, exc)
                else:
                    response = await run_in_threadpool(handler, request, exc)
                await response(scope, receive, sender)
            elif scope["type"] == "websocket":
                # Websocket handlers act on the connection directly; any return
                # value is ignored.
                websocket = WebSocket(scope, receive=receive, send=send)
                if is_async_callable(handler):
                    await handler(websocket, exc)
                else:
                    await run_in_threadpool(handler, websocket, exc)

    def http_exception(self, request: Request, exc: HTTPException) -> Response:
        """Default ``HTTPException`` handler.

        Returns a plain-text response whose body is ``exc.detail``. It carries
        ``exc.status_code`` and ``exc.headers``. For 204 and 304 (statuses
        that HTTP does not allow to carry a body), an empty ``Response`` is
        returned instead.
        """
        if exc.status_code in {204, 304}:
            return Response(status_code=exc.status_code, headers=exc.headers)
        return PlainTextResponse(
            exc.detail, status_code=exc.status_code, headers=exc.headers
        )

    async def websocket_exception(
        self, websocket: WebSocket, exc: WebSocketException
    ) -> None:
        """Default ``WebSocketException`` handler: close with ``exc.code`` / ``exc.reason``."""
        await websocket.close(code=exc.code, reason=exc.reason)