1import json
2import logging
3import re
4import typing as t
5from contextvars import ContextVar
6
7from starlette.requests import Request as StarletteRequest
8from starlette.responses import RedirectResponse
9from starlette.responses import Response as StarletteResponse
10from starlette.routing import Router
11from starlette.staticfiles import StaticFiles
12from starlette.templating import Jinja2Templates
13from starlette.types import ASGIApp, Receive, Scope, Send
14
15from connexion.jsonifier import Jsonifier
16from connexion.middleware import SpecMiddleware
17from connexion.middleware.abstract import AbstractSpecAPI
18from connexion.options import SwaggerUIConfig, SwaggerUIOptions
19from connexion.spec import Specification
20from connexion.utils import yamldumper
21
# Logger shared by every route/handler registered in this module.
logger = logging.getLogger(name="connexion.middleware.swagger_ui")


# Copy of the incoming HTTP scope for the request being handled. It is set by
# SwaggerUIMiddleware.__call__ and read by SwaggerUIMiddleware.default_fn, so that
# unmatched requests can be forwarded with the scope as it was before routing.
_original_scope: ContextVar[Scope] = ContextVar("SCOPE")
26
27
class SwaggerUIAPI(AbstractSpecAPI):
    """Serve the OpenAPI specification and the Swagger UI console for one API.

    Depending on the :class:`SwaggerUIConfig` built from the user options, GET routes
    are registered on ``self.router`` for:

    * the spec as JSON (and as YAML when the JSON path ends in ``json``), and
    * the Swagger UI home page, its optional config document, a redirect from the
      slash-less UI path, and the static UI assets.

    Requests matching none of these routes are handed to the ``default`` app.
    """

    def __init__(
        self,
        *args,
        default: ASGIApp,
        swagger_ui_options: t.Optional[SwaggerUIOptions] = None,
        **kwargs
    ):
        """
        :param args: Positional arguments forwarded to :class:`AbstractSpecAPI`
            (e.g. the specification).
        :param default: ASGI app the internal router falls back to when no route matches.
        :param swagger_ui_options: User options controlling which spec / UI routes
            are exposed and where.
        :param kwargs: Keyword arguments forwarded to :class:`AbstractSpecAPI`
            (e.g. ``base_path`` and ``arguments``).
        """
        super().__init__(*args, **kwargs)

        self.router = Router(default=default)
        self.options = SwaggerUIConfig(
            swagger_ui_options, oas_version=self.specification.version
        )

        if self.options.openapi_spec_available:
            self.add_openapi_json()
            self.add_openapi_yaml()

        if self.options.swagger_ui_available:
            self.add_swagger_ui()

        # Only used when rendering the UI home page at request time, so it is fine
        # to create it after the routes have been registered.
        self._templates = Jinja2Templates(
            directory=str(self.options.swagger_ui_template_dir)
        )

    @staticmethod
    def normalize_string(string):
        """Strip surrounding slashes and replace every character outside
        ``[a-zA-Z0-9]`` with an underscore."""
        return re.sub(r"[^a-zA-Z0-9]", "_", string.strip("/"))

    def _base_path_for_prefix(self, request: StarletteRequest) -> str:
        """
        returns a modified basePath which includes the incoming root_path.

        Uses the scope's ``route_root_path`` when present, otherwise ``root_path``
        (or ``""``), with any trailing slashes removed.
        """
        return request.scope.get(
            "route_root_path", request.scope.get("root_path", "")
        ).rstrip("/")

    def _spec_for_prefix(self, request: StarletteRequest):
        """
        returns a spec with a modified basePath / servers block
        which corresponds to the incoming request path.
        This is needed when behind a path-altering reverse proxy.
        """
        base_path = self._base_path_for_prefix(request)
        return self.specification.with_base_path(base_path).raw

    def add_openapi_json(self):
        """
        Adds openapi json to {base_path}/openapi.json
        (or {base_path}/swagger.json for swagger2)
        """
        logger.info(
            "Adding spec json: %s%s", self.base_path, self.options.openapi_spec_path
        )
        self.router.add_route(
            methods=["GET"],
            path=self.options.openapi_spec_path,
            endpoint=self._get_openapi_json,
        )

    def add_openapi_yaml(self):
        """
        Adds openapi yaml to {base_path}/openapi.yaml
        (or {base_path}/swagger.yaml for swagger2).

        The YAML path is derived from the JSON spec path by replacing its ``json``
        suffix with ``yaml``. If the configured spec path does not end in ``json``,
        no YAML route is registered.
        """
        if not self.options.openapi_spec_path.endswith("json"):
            return

        openapi_spec_path_yaml = self.options.openapi_spec_path[: -len("json")] + "yaml"
        # The spec path carries its own leading slash, so no separator is added
        # (same format as in add_openapi_json).
        logger.debug("Adding spec yaml: %s%s", self.base_path, openapi_spec_path_yaml)
        self.router.add_route(
            methods=["GET"],
            path=openapi_spec_path_yaml,
            endpoint=self._get_openapi_yaml,
        )

    async def _get_openapi_json(self, request: StarletteRequest):
        """Return the spec, rebased on the request's prefix, as JSON."""
        # Yaml parses datetime objects when loading the spec, so we need our custom jsonifier to dump it
        jsonifier = Jsonifier()

        return StarletteResponse(
            content=jsonifier.dumps(self._spec_for_prefix(request)),
            status_code=200,
            media_type="application/json",
        )

    async def _get_openapi_yaml(self, request: StarletteRequest):
        """Return the spec, rebased on the request's prefix, as YAML."""
        return StarletteResponse(
            content=yamldumper(self._spec_for_prefix(request)),
            status_code=200,
            media_type="text/yaml",
        )

    def add_swagger_ui(self):
        """
        Adds swagger ui to {base_path}{swagger_ui_path}/.

        Registers the UI home page (``/`` and ``/index.html``), the optional
        ``swagger-ui-config.json`` document, a redirect from the slash-less UI
        path, and a static-files mount for the remaining UI assets.
        """
        console_ui_path = self.options.swagger_ui_path.strip().rstrip("/")
        logger.debug("Adding swagger-ui: %s%s/", self.base_path, console_ui_path)

        for path in (
            console_ui_path + "/",
            console_ui_path + "/index.html",
        ):
            self.router.add_route(
                methods=["GET"], path=path, endpoint=self._get_swagger_ui_home
            )

        if self.options.swagger_ui_config:
            self.router.add_route(
                methods=["GET"],
                path=console_ui_path + "/swagger-ui-config.json",
                endpoint=self._get_swagger_ui_config,
            )

        # we have to add an explicit redirect instead of relying on the
        # normalize_path_middleware because we also serve static files
        # from this dir (below)
        # When the UI lives at the API root (swagger_ui_path "/"), console_ui_path
        # is "" and there is no slash-less path to redirect from; starlette would
        # also reject a route path that does not start with "/".
        if console_ui_path:

            async def redirect(request):
                url = request.scope.get("root_path", "").rstrip("/")
                url += console_ui_path
                url += "/"
                return RedirectResponse(url=url)

            self.router.add_route(
                methods=["GET"], path=console_ui_path, endpoint=redirect
            )

        # this route will match and get a permission error when trying to
        # serve index.html, so we add the redirect above.
        self.router.mount(
            path=console_ui_path,
            app=StaticFiles(directory=str(self.options.swagger_ui_template_dir)),
            name="swagger_ui_static",
        )

    async def _get_swagger_ui_home(self, req):
        """Render the ``index.j2`` template for the Swagger UI home page.

        The template receives the request, the absolute URL of the JSON/primary
        spec (request prefix + spec path), any user-supplied template arguments,
        and ``configUrl`` when a Swagger UI config is set.
        """
        base_path = self._base_path_for_prefix(req)
        template_variables = {
            "request": req,
            "openapi_spec_url": (base_path + self.options.openapi_spec_path),
            **self.options.swagger_ui_template_arguments,
        }
        if self.options.swagger_ui_config:
            # Relative URL: the config route is registered next to the UI page
            # in add_swagger_ui.
            template_variables["configUrl"] = "swagger-ui-config.json"

        return self._templates.TemplateResponse("index.j2", template_variables)

    async def _get_swagger_ui_config(self, request: StarletteRequest):
        """Return the configured Swagger UI config serialized as JSON."""
        return StarletteResponse(
            status_code=200,
            media_type="application/json",
            content=json.dumps(self.options.swagger_ui_config),
        )
182
183
class SwaggerUIMiddleware(SpecMiddleware):
    """Middleware that serves the OpenAPI spec and Swagger UI of every added API.

    HTTP requests are routed to the :class:`SwaggerUIAPI` mounted for each API.
    Requests matching none of those routes, and all non-HTTP traffic, are passed
    on to the wrapped app.
    """

    def __init__(self, app: ASGIApp) -> None:
        """Middleware that hosts a swagger UI.

        :param app: app to wrap in middleware.
        """
        self.app = app
        # Set default to pass unknown routes to next app
        self.router = Router(default=self.default_fn)

    def add_api(
        self,
        specification: Specification,
        base_path: t.Optional[str] = None,
        arguments: t.Optional[dict] = None,
        **kwargs
    ) -> None:
        """Add an API to the router based on a OpenAPI spec.

        :param specification: OpenAPI spec.
        :param base_path: Base path where to add this API.
        :param arguments: Jinja arguments to replace in the spec.
        :param kwargs: Additional keyword arguments forwarded to
            :class:`SwaggerUIAPI` (e.g. ``swagger_ui_options``).
        """
        api = SwaggerUIAPI(
            specification,
            base_path=base_path,
            arguments=arguments,
            default=self.default_fn,
            **kwargs
        )
        self.router.mount(api.base_path, app=api.router)

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Route HTTP requests; forward any other scope type directly to the app."""
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        # Keep an untouched copy of the scope for default_fn (the router mutates
        # the scope when descending into mounts). Reset the variable afterwards,
        # even on error, so the per-request scope does not outlive the request in
        # the caller's context.
        token = _original_scope.set(scope.copy())  # type: ignore
        try:
            await self.router(scope, receive, send)
        finally:
            _original_scope.reset(token)

    async def default_fn(self, _scope: Scope, receive: Receive, send: Send) -> None:
        """
        Callback to call next app as default when no matching route is found.

        Unfortunately we cannot just pass the next app as default, since the router manipulates
        the scope when descending into mounts, losing information about the base path. Therefore,
        we use the original scope instead.

        This is caused by https://github.com/encode/starlette/issues/1336.
        """
        original_scope = _original_scope.get()
        await self.app(original_scope, receive, send)