1import abc
2import logging
3import typing as t
4from collections import defaultdict
5
6from starlette.types import ASGIApp, Receive, Scope, Send
7
8from connexion.exceptions import MissingMiddleware, ResolverError
9from connexion.http_facts import METHODS
10from connexion.operations import AbstractOperation
11from connexion.resolver import Resolver
12from connexion.spec import Specification
13
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Key under which the routing information is looked up in ``scope["extensions"]``.
ROUTING_CONTEXT = "connexion_routing"
17
18
class SpecMiddleware(abc.ABC):
    """Abstract base for middlewares that get OpenAPI specification(s) registered on them.

    Concrete middlewares implement both :meth:`add_api` and the ASGI ``__call__``.
    """

    @abc.abstractmethod
    def add_api(self, specification: Specification, **kwargs) -> t.Any:
        """Register the API described by one OpenAPI specification on this middleware.

        A single middleware can have several APIs registered on it.

        :param specification: The specification describing the API to register.
        :param kwargs: Additional, implementation-specific options.
        """

    @abc.abstractmethod
    async def __call__(self, scope: Scope, receive: Receive, send: Send):
        """ASGI entry point; the implementation is left to concrete middlewares."""
33
34
class AbstractSpecAPI:
    """Minimal API base class whose behavior is limited to the specification itself."""

    def __init__(
        self,
        specification: Specification,
        base_path: t.Optional[str] = None,
        resolver: t.Optional[Resolver] = None,
        uri_parser_class=None,
        *args,
        **kwargs,
    ):
        """
        :param specification: The specification this API is built from.
        :param base_path: Base path to use instead of the one found on the specification.
        :param resolver: Resolver to use; a default ``Resolver()`` is created when omitted.
        :param uri_parser_class: Stored unchanged on the instance.
        :param args: Accepted but not used by this class.
        :param kwargs: Accepted but not used by this class.
        """
        self.specification = specification
        self.uri_parser_class = uri_parser_class

        # Needs ``self.specification`` to be in place already.
        self._set_base_path(base_path)

        self.resolver = resolver if resolver else Resolver()

    def _set_base_path(self, base_path: t.Optional[str] = None) -> None:
        """Decide which base path this API uses and store it on ``self.base_path``.

        :param base_path: User-provided base path, or ``None`` to use the
            specification's own base path.
        """
        if base_path is None:
            # Nothing provided by the user: take over what the specification says.
            self.base_path = self.specification.base_path
            return

        # A user-provided base path is written back to the specification as well.
        self.specification.base_path = base_path
        self.base_path = base_path
61
62
# Unconstrained type variable; the generic API classes in this module use it for
# the type of operation they build and register.
OP = t.TypeVar("OP")
"""Typevar representing an operation"""
65
66
class AbstractRoutingAPI(AbstractSpecAPI, t.Generic[OP]):
    """Base API class with shared functionality related to routing.

    On construction, every operation found under the specification's ``paths``
    is registered through the hooks a subclass provides:
    :meth:`make_operation`, :meth:`_framework_path_and_name` and
    :meth:`_add_operation_internal`.
    """

    def __init__(
        self,
        *args,
        pythonic_params=False,
        resolver_error_handler: t.Optional[t.Callable] = None,
        **kwargs,
    ) -> None:
        """
        :param args: Positional arguments forwarded to the parent constructor.
        :param pythonic_params: Stored unchanged on the instance.
        :param resolver_error_handler: Optional callable that receives a
            ``ResolverError`` and returns the operation to register in place of
            the one that could not be resolved.
        :param kwargs: Keyword arguments forwarded to the parent constructor.
        """
        super().__init__(*args, **kwargs)
        self.pythonic_params = pythonic_params
        self.resolver_error_handler = resolver_error_handler

        # Registration happens eagerly, as part of construction.
        self.add_paths()

    def add_paths(self, paths: t.Optional[dict] = None) -> None:
        """
        Adds the paths defined in the specification as operations.

        :param paths: Mapping of path to its methods. When omitted (or empty),
            the ``paths`` section of the specification is used.
        :raises Exception: Any error raised while adding an operation is logged
            and re-raised, except for a ``ResolverError`` when a
            ``resolver_error_handler`` is configured.
        """
        paths = t.cast(dict, paths or self.specification.get("paths", {}))
        for path, methods in paths.items():
            logger.debug("Adding %s%s...", self.base_path, path)

            for method in methods:
                # Keys that are not listed in METHODS are not operations.
                if method not in METHODS:
                    continue
                try:
                    self.add_operation(path, method)
                except ResolverError as err:
                    # If we have an error handler for resolver errors, add it as an operation.
                    # Otherwise treat it as any other error.
                    if self.resolver_error_handler is not None:
                        self._add_resolver_error_handler(method, path, err)
                    else:
                        self._handle_add_operation_error(path, method, err)
                except Exception as e:
                    # All other relevant exceptions should be handled as well.
                    self._handle_add_operation_error(path, method, e)

    def add_operation(self, path: str, method: str) -> None:
        """
        Adds one operation to the api.

        This method uses the OperationID to identify the module and function that will
        handle the operation.

        From Swagger Specification:

        **OperationID**

        A friendly name for the operation. The id MUST be unique among all operations described in the API.
        Tools and libraries MAY use the operation id to uniquely identify an operation.

        :param path: Path of the operation as written in the specification.
        :param method: Method of the operation.
        """
        spec_operation_cls = self.specification.operation_cls
        spec_operation = spec_operation_cls.from_spec(
            self.specification,
            path=path,
            method=method,
            resolver=self.resolver,
            uri_parser_class=self.uri_parser_class,
        )
        operation = self.make_operation(spec_operation)
        path, name = self._framework_path_and_name(spec_operation, path)
        self._add_operation_internal(method, path, operation, name=name)

    @abc.abstractmethod
    def make_operation(self, operation: AbstractOperation) -> OP:
        """Build an operation to register on the API."""

    @staticmethod
    def _framework_path_and_name(
        operation: AbstractOperation, path: str
    ) -> t.Tuple[str, str]:
        """Prepare the framework path & name to register the operation on the API.

        Subclasses must override this hook.

        :param operation: The operation built from the specification.
        :param path: Path of the operation as written in the specification.
        :return: Tuple of ``(path, name)`` to register the operation under.
        :raises NotImplementedError: Always, in this base implementation.
        """
        # A docstring-only body would return None, and the tuple unpacking in
        # ``add_operation`` would then fail with an unrelated TypeError. Fail
        # with an explicit error instead.
        raise NotImplementedError

    @abc.abstractmethod
    def _add_operation_internal(
        self, method: str, path: str, operation: OP, name: t.Optional[str] = None
    ) -> None:
        """
        Adds the operation according to the user framework in use.
        It will be used to register the operation on the user framework router.
        """

    def _add_resolver_error_handler(
        self, method: str, path: str, err: ResolverError
    ) -> None:
        """
        Adds a handler for ResolverError for the given method and path.

        The configured ``resolver_error_handler`` is called with ``err`` and its
        return value is registered as the operation, under the path as passed in.
        """
        # Callers only get here when a handler is configured; the cast narrows
        # the Optional for type checkers without touching the attribute.
        handler = t.cast(t.Callable, self.resolver_error_handler)
        operation = handler(
            err,
        )
        self._add_operation_internal(method, path, operation)

    def _handle_add_operation_error(
        self, path: str, method: str, exc: Exception
    ) -> None:
        """Log which operation could not be added, then re-raise ``exc``.

        :param path: Path of the failing operation, relative to the base path.
        :param method: Method of the failing operation.
        :param exc: The exception that made adding the operation fail.
        :raises Exception: Always re-raises ``exc``, with exception context suppressed.
        """
        url = f"{self.base_path}{path}"
        error_msg = f"Failed to add operation for {method.upper()} {url}"
        logger.error(error_msg)
        raise exc from None
170
171
class RoutedAPI(AbstractSpecAPI, t.Generic[OP]):
    """API base class that keeps the operations it builds in a mapping keyed by operation id."""

    def __init__(
        self,
        specification: Specification,
        *args,
        next_app: ASGIApp,
        **kwargs,
    ) -> None:
        """
        :param specification: The specification this API is built from.
        :param args: Positional arguments forwarded to the parent constructor.
        :param next_app: ASGI application stored on the instance as ``next_app``.
        :param kwargs: Keyword arguments forwarded to the parent constructor.
        """
        super().__init__(specification, *args, **kwargs)
        self.next_app = next_app
        self.operations: t.MutableMapping[t.Optional[str], OP] = {}

    def add_paths(self) -> None:
        """Add an operation for every path and method found in the specification."""
        for path, methods in self.specification.get("paths", {}).items():
            # Lazily filter out keys that are not listed in METHODS.
            for method in (key for key in methods if key in METHODS):
                try:
                    self.add_operation(path, method)
                except ResolverError:
                    # ResolverErrors are either raised or handled in routing middleware.
                    pass

    def add_operation(self, path: str, method: str) -> None:
        """Build the operation for ``path`` and ``method`` and store it by its operation id.

        :param path: Path of the operation as written in the specification.
        :param method: Method of the operation.
        """
        spec_operation = self.specification.operation_cls.from_spec(
            self.specification,
            path=path,
            method=method,
            resolver=self.resolver,
            uri_parser_class=self.uri_parser_class,
        )
        self.operations[spec_operation.operation_id] = self.make_operation(
            spec_operation
        )

    @abc.abstractmethod
    def make_operation(self, operation: AbstractOperation) -> OP:
        """Create an operation of the `operation_cls` type."""
        raise NotImplementedError
212
213
# The bound is given as a string (forward reference); only RoutedAPI and its
# subclasses satisfy this type variable.
API = t.TypeVar("API", bound="RoutedAPI")
"""Typevar representing an API which subclasses RoutedAPI"""
216
217
class RoutedMiddleware(SpecMiddleware, t.Generic[API]):
    """Baseclass for middleware that wants to leverage the RoutingMiddleware to route requests to
    its operations.

    The RoutingMiddleware adds the operation_id to the ASGI scope. This middleware registers its
    operations by operation_id at startup. At request time, the operation is fetched by an
    operation_id lookup.
    """

    api_cls: t.Type[API]
    """The subclass of RoutedAPI this middleware uses."""

    def __init__(self, app: ASGIApp) -> None:
        """
        :param app: The next ASGI application; requests this middleware does not
            handle itself are passed on to it.
        """
        self.app = app
        # Registered APIs, grouped by base path.
        self.apis: t.Dict[str, t.List[API]] = defaultdict(list)

    def add_api(self, specification: Specification, **kwargs) -> API:
        """Create an API of type ``api_cls`` for the specification and register it.

        :param specification: The specification describing the API.
        :param kwargs: Extra keyword arguments forwarded to ``api_cls``.
        :return: The newly created API.
        """
        api = self.api_cls(specification, next_app=self.app, **kwargs)
        self.apis[api.base_path].append(api)
        return api

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Fetches the operation related to the request and calls it.

        Scopes of a type other than ``"http"``, requests whose base path has no
        API registered here, and requests without an operation id are passed on
        to the wrapped application.

        :raises MissingMiddleware: If the scope holds no routing information.
        :raises MissingOperation: If none of the APIs registered under the
            request's base path knows the operation id.
        """
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        try:
            connexion_context = scope["extensions"][ROUTING_CONTEXT]
        except KeyError as err:
            # Chain explicitly so the traceback shows which key was missing.
            raise MissingMiddleware(
                "Could not find routing information in scope. Please make sure "
                "you have a routing middleware registered upstream. "
            ) from err

        api_base_path = connexion_context.get("api_base_path")
        # The membership test keeps the defaultdict from growing on unknown base paths.
        if api_base_path is not None and api_base_path in self.apis:
            # The operation id does not depend on the API, so look it up once.
            operation_id = connexion_context.get("operation_id")
            for api in self.apis[api_base_path]:
                try:
                    operation = api.operations[operation_id]
                except KeyError:
                    if operation_id is None:
                        logger.debug("Skipping operation without id.")
                        await self.app(scope, receive, send)
                        return
                    # Otherwise: try the next API registered under this base path.
                else:
                    return await operation(scope, receive, send)

            raise MissingOperation("Encountered unknown operation_id.")

        await self.app(scope, receive, send)
269
270
class MissingOperation(Exception):
    """Raised when an operation_id cannot be found on any API registered for a base path."""