Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/connexion/middleware/exceptions.py: 58%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

52 statements  

1import asyncio 

2import functools 

3import logging 

4import typing as t 

5 

6from starlette.concurrency import run_in_threadpool 

7from starlette.exceptions import HTTPException 

8from starlette.middleware.exceptions import ( 

9 ExceptionMiddleware as StarletteExceptionMiddleware, 

10) 

11from starlette.requests import Request as StarletteRequest 

12from starlette.responses import Response as StarletteResponse 

13from starlette.types import ASGIApp, Receive, Scope, Send 

14 

15from connexion import http_facts 

16from connexion.exceptions import InternalServerError, ProblemException, problem 

17from connexion.lifecycle import ConnexionRequest, ConnexionResponse 

18from connexion.types import MaybeAwaitable 

19 

# Module-wide logger; the default exception handlers report handled errors through it.
logger = logging.getLogger(name=__name__)

21 

22 

def connexion_wrapper(
    handler: t.Callable[
        [ConnexionRequest, Exception], MaybeAwaitable[ConnexionResponse]
    ]
) -> t.Callable[[StarletteRequest, Exception], t.Awaitable[StarletteResponse]]:
    """Adapt a Connexion-style error handler to Starlette's handler interface.

    The returned coroutine function:

    1. converts the incoming Starlette request into a ``ConnexionRequest``;
    2. calls ``handler``, awaiting it directly when it is a coroutine function and
       otherwise running it via Starlette's ``run_in_threadpool``;
    3. translates the resulting Connexion response into a Starlette ``Response``.

    :param handler: callable taking ``(ConnexionRequest, Exception)`` and returning a
        ``ConnexionResponse`` or an awaitable resolving to one. If it returns a
        Starlette ``Response`` instead, that response is passed through unchanged.
    :return: an async callable accepting ``(StarletteRequest, Exception)`` and
        returning a Starlette ``Response``.
    """

    @functools.wraps(handler)
    async def wrapper(request: StarletteRequest, exc: Exception) -> StarletteResponse:
        # Keep the Starlette request and the translated request in separate names
        # instead of rebinding the parameter to a different type.
        connexion_request = ConnexionRequest.from_starlette_request(request)

        if asyncio.iscoroutinefunction(handler):
            response = await handler(connexion_request, exc)  # type: ignore
        else:
            response = await run_in_threadpool(handler, connexion_request, exc)

        # A callable that is not itself a coroutine function may still return a
        # coroutine; resolve it until a concrete response object remains.
        while asyncio.iscoroutine(response):
            response = await response

        # A handler that already built a Starlette response needs no translation.
        # Converting it would fail, because Starlette responses expose `media_type`,
        # not `mimetype`.
        if isinstance(response, StarletteResponse):
            return response

        return StarletteResponse(
            content=response.body,
            status_code=response.status_code,
            media_type=response.mimetype,
            headers=response.headers,
        )

    return wrapper

52 

53 

def _log_for_status(status_code: int, exc: Exception) -> None:
    """Log a handled exception: warning for 4xx statuses, error for anything else."""
    if 400 <= status_code <= 499:
        logger.warning("%r", exc)
    else:
        logger.error("%r", exc)


class ExceptionMiddleware(StarletteExceptionMiddleware):
    """Subclass of starlette ExceptionMiddleware to change handling of HTTP exceptions to
    existing connexion behavior.

    Every handler registered through :meth:`add_exception_handler` is wrapped with
    ``connexion_wrapper``. Handlers therefore receive a ``ConnexionRequest`` and return
    a ``ConnexionResponse``, which is translated to a Starlette response. By default,
    ``ProblemException`` and any other ``Exception`` are handled by
    :meth:`problem_handler` and :meth:`common_error_handler` respectively.
    """

    def __init__(self, next_app: ASGIApp):
        """Wrap ``next_app`` and register the default Connexion exception handlers."""
        super().__init__(next_app)
        # The handler's narrower `exc` type does not match the generic signature,
        # hence the ignore.
        self.add_exception_handler(ProblemException, self.problem_handler)  # type: ignore
        self.add_exception_handler(Exception, self.common_error_handler)

    def add_exception_handler(
        self,
        exc_class_or_status_code: t.Union[int, t.Type[Exception]],
        handler: t.Callable[
            [ConnexionRequest, Exception], MaybeAwaitable[ConnexionResponse]
        ],
    ) -> None:
        """Register a Connexion-style handler for an exception class or status code.

        :param exc_class_or_status_code: exception type or HTTP status code to handle.
        :param handler: callable taking ``(ConnexionRequest, Exception)`` and returning
            a ``ConnexionResponse`` (or an awaitable of one). It is wrapped with
            ``connexion_wrapper`` before being handed to Starlette.
        """
        super().add_exception_handler(
            exc_class_or_status_code, handler=connexion_wrapper(handler)
        )

    @staticmethod
    def problem_handler(
        _request: ConnexionRequest, exc: ProblemException
    ) -> ConnexionResponse:
        """Default handler for Connexion ProblemExceptions.

        Logs the exception (warning for 4xx, error otherwise) and renders it via
        ``exc.to_problem()``.
        """
        _log_for_status(exc.status, exc)
        return exc.to_problem()

    @staticmethod
    @connexion_wrapper
    def http_exception(
        _request: ConnexionRequest, exc: HTTPException, **kwargs
    ) -> ConnexionResponse:
        """Default handler for Starlette HTTPException.

        Replaces the base class's ``http_exception``. Because it is wrapped by
        ``connexion_wrapper``, it receives a ``ConnexionRequest``, and the problem
        response it returns is translated into a Starlette response. Extra keyword
        arguments are accepted and ignored.
        """
        _log_for_status(exc.status_code, exc)
        return problem(
            title=http_facts.HTTP_STATUS_CODES.get(exc.status_code),
            detail=exc.detail,
            status=exc.status_code,
            headers=exc.headers,
        )

    @staticmethod
    def common_error_handler(
        _request: ConnexionRequest, exc: Exception
    ) -> ConnexionResponse:
        """Default handler for any unhandled Exception.

        Logs the exception with its traceback and returns a generic internal server
        error problem response, so exception details are not sent to the client.
        """
        logger.error("%r", exc, exc_info=exc)
        return InternalServerError().to_problem()

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """ASGI entry point; delegates directly to Starlette's ExceptionMiddleware."""
        await super().__call__(scope, receive, send)