1import json
2import logging
3import re
4import typing as t
5from contextvars import ContextVar
6
7from starlette.requests import Request as StarletteRequest
8from starlette.responses import RedirectResponse
9from starlette.responses import Response as StarletteResponse
10from starlette.routing import Router
11from starlette.staticfiles import StaticFiles
12from starlette.templating import Jinja2Templates
13from starlette.types import ASGIApp, Receive, Scope, Send
14
15from connexion.jsonifier import Jsonifier
16from connexion.middleware import SpecMiddleware
17from connexion.middleware.abstract import AbstractSpecAPI
18from connexion.options import SwaggerUIConfig, SwaggerUIOptions
19from connexion.spec import Specification
20from connexion.utils import yamldumper
21
# Module-level logger; the name is spelled out explicitly instead of using __name__.
logger = logging.getLogger("connexion.middleware.swagger_ui")


# Context variable holding a copy of the incoming ASGI scope. It is set in
# SwaggerUIMiddleware.__call__ before the request is handed to the router and
# read back in SwaggerUIMiddleware.default_fn when no route matched.
_original_scope: ContextVar[Scope] = ContextVar("SCOPE")
26
27
class SwaggerUIAPI(AbstractSpecAPI):
    """Router exposing the spec documents and the Swagger UI for a single API.

    Depending on the resolved ``SwaggerUIConfig`` the router serves the
    specification as JSON and YAML, the Swagger UI index page, an optional
    ``swagger-ui-config.json`` and the static UI assets. Requests matching
    none of these routes are passed to the ``default`` ASGI app.
    """

    def __init__(
        self,
        *args,
        default: ASGIApp,
        swagger_ui_options: t.Optional[SwaggerUIOptions] = None,
        **kwargs
    ):
        """Build the router and register the enabled routes.

        :param args: Positional arguments forwarded to ``AbstractSpecAPI``.
        :param default: ASGI app the router falls back to when no route matches.
        :param swagger_ui_options: Optional user supplied Swagger UI options.
        :param kwargs: Keyword arguments forwarded to ``AbstractSpecAPI``.
        """
        # NOTE(review): ``self.specification`` and ``self.base_path`` are used
        # below and are presumably set by the parent initializer.
        super().__init__(*args, **kwargs)

        self.router = Router(default=default)
        self.options = SwaggerUIConfig(
            swagger_ui_options, oas_version=self.specification.version
        )

        if self.options.openapi_spec_available:
            self.add_openapi_json()
            self.add_openapi_yaml()

        if self.options.swagger_ui_available:
            self.add_swagger_ui()

        # Only needed at request time by ``_get_swagger_ui_home``, so it is
        # safe to create it after the routes have been registered.
        self._templates = Jinja2Templates(
            directory=str(self.options.swagger_ui_template_dir)
        )

    @staticmethod
    def normalize_string(string):
        """Strip surrounding slashes and replace every non-alphanumeric
        ASCII character of ``string`` with an underscore.
        """
        return re.sub(r"[^a-zA-Z0-9]", "_", string.strip("/"))

    def _base_path_for_prefix(self, request: StarletteRequest) -> str:
        """Return the path prefix under which this request was routed.

        The scope's ``route_root_path`` is preferred; when it is absent the
        scope's ``root_path`` is used, and an empty string when neither is
        set. Trailing slashes are removed.
        """
        return request.scope.get(
            "route_root_path", request.scope.get("root_path", "")
        ).rstrip("/")

    def _spec_for_prefix(self, request: StarletteRequest) -> dict:
        """Return the raw spec rebased onto the incoming request prefix.

        The basePath / servers block is rewritten to correspond to the path
        of the incoming request. This is needed when running behind a
        path-altering reverse proxy.
        """
        base_path = self._base_path_for_prefix(request)
        return self.specification.with_base_path(base_path).raw

    def add_openapi_json(self):
        """
        Adds openapi json to {base_path}/openapi.json
        (or {base_path}/swagger.json for swagger2)
        """
        logger.info(
            "Adding spec json: %s%s", self.base_path, self.options.openapi_spec_path
        )
        self.router.add_route(
            methods=["GET"],
            path=self.options.openapi_spec_path,
            endpoint=self._get_openapi_json,
        )

    def add_openapi_yaml(self):
        """
        Adds openapi yaml to {base_path}/openapi.yaml
        (or {base_path}/swagger.yaml for swagger2)

        The yaml path is derived from the configured json spec path by
        swapping its trailing "json" for "yaml". Nothing is registered when
        the configured spec path does not end with "json".
        """
        if not self.options.openapi_spec_path.endswith("json"):
            return

        openapi_spec_path_yaml = self.options.openapi_spec_path[: -len("json")] + "yaml"
        # The derived path keeps the leading slash of the json path, so no
        # extra separator is inserted (consistent with ``add_openapi_json``).
        logger.debug("Adding spec yaml: %s%s", self.base_path, openapi_spec_path_yaml)
        self.router.add_route(
            methods=["GET"],
            path=openapi_spec_path_yaml,
            endpoint=self._get_openapi_yaml,
        )

    async def _get_openapi_json(self, request):
        """Serve the prefix-adjusted specification as ``application/json``."""
        # Yaml parses datetime objects when loading the spec, so we need our
        # custom jsonifier to dump it
        jsonifier = Jsonifier()

        return StarletteResponse(
            content=jsonifier.dumps(self._spec_for_prefix(request)),
            status_code=200,
            media_type="application/json",
        )

    async def _get_openapi_yaml(self, request):
        """Serve the prefix-adjusted specification as ``text/yaml``."""
        return StarletteResponse(
            content=yamldumper(self._spec_for_prefix(request)),
            status_code=200,
            media_type="text/yaml",
        )

    def add_swagger_ui(self):
        """
        Adds swagger ui to {base_path}{swagger_ui_path}/

        Registers the index page (with and without ``index.html``), the
        optional ``swagger-ui-config.json``, a redirect from the slash-less
        UI path, and a static file mount for the remaining UI assets.
        """
        console_ui_path = self.options.swagger_ui_path.strip().rstrip("/")
        logger.debug("Adding swagger-ui: %s%s/", self.base_path, console_ui_path)

        for path in (
            console_ui_path + "/",
            console_ui_path + "/index.html",
        ):
            self.router.add_route(
                methods=["GET"], path=path, endpoint=self._get_swagger_ui_home
            )

        if self.options.swagger_ui_config:
            self.router.add_route(
                methods=["GET"],
                path=console_ui_path + "/swagger-ui-config.json",
                endpoint=self._get_swagger_ui_config,
            )

        # we have to add an explicit redirect instead of relying on the
        # normalize_path_middleware because we also serve static files
        # from this dir (below)

        async def redirect(request):
            url = request.scope.get("root_path", "").rstrip("/")
            url += console_ui_path
            url += "/"
            return RedirectResponse(url=url)

        self.router.add_route(methods=["GET"], path=console_ui_path, endpoint=redirect)

        # this route will match and get a permission error when trying to
        # serve index.html, so we add the redirect above.
        self.router.mount(
            path=console_ui_path,
            app=StaticFiles(directory=str(self.options.swagger_ui_template_dir)),
            name="swagger_ui_static",
        )

    async def _get_swagger_ui_home(self, req):
        """Render the ``index.j2`` Swagger UI page for this request.

        ``openapi_spec_url`` is the request prefix plus the configured spec
        path; user supplied template arguments are applied afterwards and
        may therefore override it.
        """
        base_path = self._base_path_for_prefix(req)
        template_variables = {
            "openapi_spec_url": (base_path + self.options.openapi_spec_path),
            **self.options.swagger_ui_template_arguments,
        }
        if self.options.swagger_ui_config:
            # Relative URL: the config route is registered next to the index page.
            template_variables["configUrl"] = "swagger-ui-config.json"

        return self._templates.TemplateResponse(
            req, name="index.j2", context=template_variables
        )

    async def _get_swagger_ui_config(self, request):
        """Serve the configured Swagger UI config as ``application/json``."""
        return StarletteResponse(
            status_code=200,
            media_type="application/json",
            content=json.dumps(self.options.swagger_ui_config),
        )
183
184
class SwaggerUIMiddleware(SpecMiddleware):
    """ASGI middleware that serves the Swagger UI in front of another app."""

    def __init__(self, app: ASGIApp) -> None:
        """Middleware that hosts a swagger UI.

        :param app: app to wrap in middleware.
        """
        self.app = app
        # Anything the router cannot match is handed to ``default_fn``, which
        # forwards the request to the wrapped app.
        self.router = Router(default=self.default_fn)

    def add_api(
        self,
        specification: Specification,
        base_path: t.Optional[str] = None,
        arguments: t.Optional[dict] = None,
        **kwargs
    ) -> None:
        """Mount the Swagger UI routes of an API described by an OpenAPI spec.

        :param specification: OpenAPI spec.
        :param base_path: Base path where to add this API.
        :param arguments: Jinja arguments to replace in the spec.
        :param kwargs: Extra keyword arguments forwarded to ``SwaggerUIAPI``.
        """
        swagger_api = SwaggerUIAPI(
            specification,
            base_path=base_path,
            arguments=arguments,
            default=self.default_fn,
            **kwargs
        )
        self.router.mount(swagger_api.base_path, app=swagger_api.router)

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Route http requests through the UI router; pass the rest along."""
        if scope["type"] == "http":
            # Keep an untouched copy of the scope for ``default_fn``.
            _original_scope.set(scope.copy())  # type: ignore
            await self.router(scope, receive, send)
        else:
            await self.app(scope, receive, send)

    async def default_fn(self, _scope: Scope, receive: Receive, send: Send) -> None:
        """
        Fallback invoked by the router when no route matches; calls the next app.

        The next app cannot simply be registered as the router default: the
        router manipulates the scope when descending into mounts, which loses
        information about the base path. The scope saved in ``__call__`` is
        therefore used in place of the one received here.

        This is caused by https://github.com/encode/starlette/issues/1336.
        """
        await self.app(_original_scope.get(), receive, send)