1import typing as t
2from contextvars import ContextVar
3
4import starlette.convertors
5from starlette.routing import Router
6from starlette.types import ASGIApp, Receive, Scope, Send
7
8from connexion.frameworks import starlette as starlette_utils
9from connexion.middleware.abstract import (
10 ROUTING_CONTEXT,
11 AbstractRoutingAPI,
12 SpecMiddleware,
13)
14from connexion.operations import AbstractOperation
15from connexion.resolver import Resolver
16from connexion.spec import Specification
17
# Shallow copy of the ASGI scope as it entered RoutingMiddleware.__call__.
# RoutingOperation reads it back and forwards this copy (enriched with the routing
# results) to the next app, instead of the scope object the Starlette router passes
# to the matched endpoint.
_scope: ContextVar[dict] = ContextVar("SCOPE")
19
20
class RoutingOperation:
    """ASGI endpoint registered on the Starlette router for a single operation.

    When the router dispatches a request to it, it records the routing results
    (path parameters, API base path and operation id) on the scope saved in
    ``_scope`` by ``RoutingMiddleware`` and forwards that scope to the next app.
    """

    def __init__(self, operation_id: t.Optional[str], next_app: ASGIApp) -> None:
        # ``None`` is used for the router's fallback (default) endpoint, i.e. for
        # requests that did not match any registered operation.
        self.operation_id = operation_id
        self.next_app = next_app

    @classmethod
    def from_operation(cls, operation: AbstractOperation, next_app: ASGIApp):
        """Build a routing endpoint for ``operation`` that forwards to ``next_app``."""
        return cls(operation.operation_id, next_app)

    @staticmethod
    def _get_root_path(scope: Scope) -> str:
        """Return the scope's root path, preferring ``route_root_path`` if present."""
        return scope.get("route_root_path", scope.get("root_path", ""))

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Attach operation to scope and pass it to the next app.

        :param scope: scope as modified by the Starlette router (contains the
            resolved ``path_params`` and the mounted root path).
        :param receive: ASGI receive callable, passed through unchanged.
        :param send: ASGI send callable, passed through unchanged.
        """
        original_scope = _scope.get()
        # Pass resolved path params along. A new dict is built rather than updating
        # in place: ``original_scope`` is a shallow copy, so an existing
        # ``path_params`` dict may still be shared with the caller's scope.
        original_scope["path_params"] = {
            **original_scope.get("path_params", {}),
            **scope.get("path_params", {}),
        }

        # The portion of the root path added by the router (via Mount) is the base
        # path of the API that matched.
        api_base_path = self._get_root_path(scope)[
            len(self._get_root_path(original_scope)) :
        ]

        extensions = original_scope.setdefault("extensions", {})
        connexion_routing = extensions.setdefault(ROUTING_CONTEXT, {})
        connexion_routing.update(
            {"api_base_path": api_base_path, "operation_id": self.operation_id}
        )
        await self.next_app(original_scope, receive, send)
49
50
class RoutingAPI(AbstractRoutingAPI):
    """Routing API backed by a Starlette ``Router``.

    Every operation of the specification is added as a route whose endpoint is a
    ``RoutingOperation`` forwarding to ``next_app``.
    """

    def __init__(
        self,
        specification: Specification,
        *,
        next_app: ASGIApp,
        base_path: t.Optional[str] = None,
        arguments: t.Optional[dict] = None,
        resolver: t.Optional[Resolver] = None,
        resolver_error_handler: t.Optional[t.Callable] = None,
        debug: bool = False,
        **kwargs,
    ) -> None:
        """API implementation on top of Starlette Router for Connexion middleware.

        :param specification: OpenAPI specification of the API.
        :param next_app: ASGI app that routed (and unmatched) requests are forwarded to.
        :param base_path: forwarded to ``AbstractRoutingAPI``.
        :param arguments: forwarded to ``AbstractRoutingAPI``.
        :param resolver: forwarded to ``AbstractRoutingAPI``.
        :param resolver_error_handler: forwarded to ``AbstractRoutingAPI``.
        :param debug: forwarded to ``AbstractRoutingAPI``.
        """
        self.next_app = next_app
        # Unless replaced later, requests matching no route of this API still reach
        # ``next_app``, with no operation id. The router is created before calling
        # the base initializer, which presumably registers the operations through
        # ``_add_operation_internal`` (TODO confirm) and therefore needs it.
        self.router = Router(default=RoutingOperation(None, next_app))

        super().__init__(
            specification,
            base_path=base_path,
            arguments=arguments,
            resolver=resolver,
            resolver_error_handler=resolver_error_handler,
            debug=debug,
            **kwargs,
        )

    def make_operation(self, operation: AbstractOperation) -> RoutingOperation:
        """Wrap ``operation`` in a ``RoutingOperation`` that forwards to ``next_app``."""
        return RoutingOperation.from_operation(operation, next_app=self.next_app)

    @staticmethod
    def _framework_path_and_name(
        operation: AbstractOperation, path: str
    ) -> t.Tuple[str, str]:
        """Convert ``path`` to Starlette syntax using the operation's parameter types.

        The converted path is returned both as the route path and as its name.
        """
        types = operation.get_path_parameter_types()
        starlette_path = starlette_utils.starlettify_path(path, types)
        return starlette_path, starlette_path

    def _add_operation_internal(
        self,
        method: str,
        path: str,
        operation: RoutingOperation,
        name: t.Optional[str] = None,
    ) -> None:
        """Register ``operation`` on the router for ``method`` requests to ``path``.

        ``name`` is accepted for interface compatibility but is not forwarded to
        the router.
        """
        self.router.add_route(path, operation, methods=[method])
93
94
class RoutingMiddleware(SpecMiddleware):
    """Resolve the operation for incoming HTTP requests before calling the next app."""

    def __init__(self, app: ASGIApp) -> None:
        """Middleware that resolves the Operation for an incoming request and attaches it to the
        scope.

        :param app: app to wrap in middleware.
        """
        self.app = app
        # Pass unknown routes to next app
        self.router = Router(default=RoutingOperation(None, self.app))
        # Replace Starlette's global "float"/"int" path convertors with the
        # Connexion versions.
        starlette.convertors.register_url_convertor(
            "float", starlette_utils.FloatConverter()
        )
        starlette.convertors.register_url_convertor(
            "int", starlette_utils.IntegerConverter()
        )

    def add_api(
        self,
        specification: Specification,
        base_path: t.Optional[str] = None,
        arguments: t.Optional[dict] = None,
        **kwargs,
    ) -> None:
        """Add an API to the router based on a OpenAPI spec.

        :param specification: OpenAPI spec.
        :param base_path: Base path where to add this API.
        :param arguments: Jinja arguments to replace in the spec.
        :param kwargs: Additional keyword arguments forwarded to ``RoutingAPI``.
        """
        api = RoutingAPI(
            specification,
            base_path=base_path,
            arguments=arguments,
            next_app=self.app,
            **kwargs,
        )

        # If APIs with the same base_path were already registered, chain the new API
        # as the default of the most recently registered one. This way, a request not
        # matched by the first API falls through each API in registration order.
        # Only the last mount is updated: updating every matching mount would make
        # the first API skip straight to the newest one, orphaning those in between.
        previous_mount = None
        for route in self.router.routes:
            if (
                isinstance(route, starlette.routing.Mount)
                and route.path == api.base_path
            ):
                previous_mount = route
        if previous_mount is not None:
            previous_mount.app.default = api.router

        self.router.mount(api.base_path, app=api.router)

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Route request to matching operation, and attach it to the scope before calling the
        next app.

        Non-HTTP scopes are passed to the next app untouched.
        """
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        # Save a copy of the incoming scope for RoutingOperation, and restore the
        # previous value afterwards so it does not outlive this request.
        token = _scope.set(scope.copy())  # type: ignore
        try:
            await self.router(scope, receive, send)
        finally:
            _scope.reset(token)