1import asyncio
2import functools
3import logging
4import typing as t
5
6from starlette.concurrency import run_in_threadpool
7from starlette.exceptions import HTTPException
8from starlette.middleware.exceptions import (
9 ExceptionMiddleware as StarletteExceptionMiddleware,
10)
11from starlette.requests import Request as StarletteRequest
12from starlette.responses import Response as StarletteResponse
13from starlette.types import ASGIApp, Receive, Scope, Send
14
15from connexion import http_facts
16from connexion.exceptions import InternalServerError, ProblemException, problem
17from connexion.lifecycle import ConnexionRequest, ConnexionResponse
18from connexion.types import MaybeAwaitable
19
# Module-wide logger, named after this module, shared by the exception handlers.
logger: logging.Logger = logging.getLogger(name=__name__)
21
22
def connexion_wrapper(
    handler: t.Callable[
        [ConnexionRequest, Exception],
        MaybeAwaitable[t.Union[ConnexionResponse, StarletteResponse]],
    ]
) -> t.Callable[[StarletteRequest, Exception], t.Awaitable[StarletteResponse]]:
    """Adapt a Connexion-style error handler to Starlette's exception-handler API.

    The returned coroutine function:

    * converts the incoming Starlette request into a :class:`ConnexionRequest`,
    * invokes ``handler`` with ``(connexion_request, exc)``, and
    * converts the :class:`ConnexionResponse` it produces into a Starlette response.

    Coroutine functions are awaited on the event loop. Any other callable runs via
    Starlette's ``run_in_threadpool``. If the call yields a coroutine, that coroutine
    is awaited, repeatedly, until a concrete response remains.

    :param handler: Callable taking ``(ConnexionRequest, Exception)`` and returning a
        ``ConnexionResponse``, either directly or via a coroutine. A
        ``StarletteResponse`` is also accepted and returned unchanged.
    :return: An async callable ``(StarletteRequest, Exception) -> StarletteResponse``.
    """

    @functools.wraps(handler)
    async def wrapper(request: StarletteRequest, exc: Exception) -> StarletteResponse:
        connexion_request = ConnexionRequest.from_starlette_request(request)

        if asyncio.iscoroutinefunction(handler):
            response = await handler(connexion_request, exc)  # type: ignore
        else:
            # Run synchronous handlers in the threadpool rather than on the event loop.
            response = await run_in_threadpool(handler, connexion_request, exc)

        # Some callables are not detected as coroutine functions but still return a
        # coroutine (e.g. objects with an async ``__call__``); resolve it here.
        while asyncio.iscoroutine(response):
            response = await response

        # A ready-made Starlette response needs no translation. It has no
        # ``mimetype`` attribute and would otherwise fail below.
        if isinstance(response, StarletteResponse):
            return response

        return StarletteResponse(
            content=response.body,
            status_code=response.status_code,
            media_type=response.mimetype,
            headers=response.headers,
        )

    return wrapper
52
53
def _log_by_status(status_code: int, exc: Exception) -> None:
    """Log ``exc`` as a warning for 4xx status codes, and as an error otherwise."""
    if 400 <= status_code <= 499:
        logger.warning("%r", exc)
    else:
        logger.error("%r", exc)


class ExceptionMiddleware(StarletteExceptionMiddleware):
    """Subclass of starlette ExceptionMiddleware to change handling of HTTP exceptions to
    existing connexion behavior.

    Every handler registered through :meth:`add_exception_handler` is wrapped with
    :func:`connexion_wrapper`. Handlers therefore receive a ``ConnexionRequest`` and
    return a ``ConnexionResponse``.

    Default registrations:

    * ``ProblemException`` -> :meth:`problem_handler`
    * ``Exception`` -> :meth:`common_error_handler`

    :meth:`http_exception` overrides the Starlette method of the same name.
    """

    def __init__(self, next_app: ASGIApp):
        """Wrap ``next_app`` and register the default Connexion error handlers."""
        super().__init__(next_app)
        # problem_handler narrows ``exc`` to ProblemException, hence the ignore.
        self.add_exception_handler(ProblemException, self.problem_handler)  # type: ignore
        self.add_exception_handler(Exception, self.common_error_handler)

    def add_exception_handler(
        self,
        exc_class_or_status_code: t.Union[int, t.Type[Exception]],
        handler: t.Callable[
            [ConnexionRequest, Exception], MaybeAwaitable[ConnexionResponse]
        ],
    ) -> None:
        """Register a Connexion-style ``handler`` for an exception class or status code.

        :param exc_class_or_status_code: Exception type, or integer HTTP status code,
            that the handler applies to.
        :param handler: Callable ``(ConnexionRequest, Exception) -> ConnexionResponse``,
            sync or async. It is adapted via :func:`connexion_wrapper` before being
            passed to Starlette.
        """
        super().add_exception_handler(
            exc_class_or_status_code, handler=connexion_wrapper(handler)
        )

    @staticmethod
    def problem_handler(
        _request: ConnexionRequest, exc: ProblemException
    ) -> ConnexionResponse:
        """Default handler for Connexion ProblemExceptions.

        Logs 4xx problems as warnings and all other statuses as errors, then returns
        the exception's own problem response.
        """
        _log_by_status(exc.status, exc)
        return exc.to_problem()

    @staticmethod
    @connexion_wrapper
    def http_exception(
        _request: ConnexionRequest, exc: HTTPException, **kwargs
    ) -> ConnexionResponse:
        """Default handler for Starlette HTTPException.

        Logs the exception according to its status code. Returns a problem response
        whose title is the standard reason phrase for that status code, and which
        carries the exception's detail and headers. ``kwargs`` is accepted but ignored.
        """
        _log_by_status(exc.status_code, exc)
        return problem(
            title=http_facts.HTTP_STATUS_CODES.get(exc.status_code),
            detail=exc.detail,
            status=exc.status_code,
            headers=exc.headers,
        )

    @staticmethod
    def common_error_handler(
        _request: ConnexionRequest, exc: Exception
    ) -> ConnexionResponse:
        """Default handler for any unhandled Exception.

        Logs the exception with its traceback and returns a generic
        ``InternalServerError`` problem response.
        """
        logger.error("%r", exc, exc_info=exc)
        return InternalServerError().to_problem()