1# -*- coding: utf-8 -*-
2import itertools
3import re
4
5from inspect import isclass, getdoc
6from collections import OrderedDict
7
8from collections.abc import Hashable
9
10from flask import current_app
11
12from . import fields
13from .model import Model, ModelBase, OrderedModel
14from .reqparse import RequestParser
15from .utils import merge, not_none, not_none_sorted
16from ._http import HTTPStatus
17
18from urllib.parse import quote
19
#: Maps Flask/Werkzeug routing types (URL converter names) to Swagger ones
PATH_TYPES = {
    "int": "integer",
    "float": "number",
    "string": "string",
    "default": "string",
}

#: Maps Python primitives types to Swagger ones
PY_TYPES = {
    int: "integer",
    float: "number",
    str: "string",
    bool: "boolean",
    None: "void",
}

#: Matches one Flask URL placeholder such as ``<int:user_id>`` or ``<user_id>``.
#: An optional ``converter:`` prefix is matched but not captured; group 1 is the
#: rest of the placeholder (the variable name).
RE_URL = re.compile(r"<(?:[^:<>]+:)?([^<>]+)>")

#: Description used for responses documented without one.
DEFAULT_RESPONSE_DESCRIPTION = "Success"
#: Template response object; callers take a ``.copy()`` before filling it in.
DEFAULT_RESPONSE = {"description": DEFAULT_RESPONSE_DESCRIPTION}

#: Matches ``:raises Name: description`` lines anywhere in a docstring
#: (``re.MULTILINE`` makes ``^``/``$`` anchor on each line).
RE_RAISES = re.compile(
    r"^:raises\s+(?P<name>[\w\d_]+)\s*:\s*(?P<description>.*)$", re.MULTILINE
)

#: Tokenizer for Flask URL rules used by :func:`parse_rule`: each match is a
#: static prefix followed by one ``<converter(args):variable>`` placeholder.
RE_PARSE_RULE = re.compile(
    r"""
    (?P<static>[^<]*)                           # static rule data
    <
    (?:
        (?P<converter>[a-zA-Z_][a-zA-Z0-9_]*)   # converter name
        (?:\((?P<args>.*?)\))?                  # converter arguments
        \:                                      # variable delimiter
    )?
    (?P<variable>[a-zA-Z_][a-zA-Z0-9_]*)        # variable name
    >
    """,
    re.VERBOSE,
)
60
61
def ref(model):
    """
    Build a JSON reference to a model in the ``definitions`` section.

    :param model: a model instance (its ``name`` is used) or a model name
    :returns: a ``{"$ref": "#/definitions/<name>"}`` dict, with the name
        fully percent-encoded
    :rtype: dict
    """
    if isinstance(model, ModelBase):
        model_name = model.name
    else:
        model_name = model
    encoded_name = quote(model_name, safe="")
    return {"$ref": "#/definitions/" + encoded_name}
66
67
68def _v(value):
69 """Dereference values (callable)"""
70 return value() if callable(value) else value
71
72
def extract_path(path):
    """
    Convert a Flask/Werkzeug URL rule into a Swagger path template.

    Every ``<converter:name>`` (or bare ``<name>``) placeholder is replaced by
    ``{name}``; static parts of the rule are left untouched.

    :param str path: the Flask URL rule, e.g. ``/users/<int:user_id>``
    :returns: the Swagger path, e.g. ``/users/{user_id}``
    :rtype: str
    """
    swagger_path = RE_URL.sub(r"{\1}", path)
    return swagger_path
78
79
def parse_rule(rule):
    """
    Lazily tokenize a Flask URL rule.

    Yields ``(converter, arguments, variable)`` tuples. Static URL chunks are
    yielded as ``(None, None, text)``. Placeholders yield their converter name
    (``"default"`` when none is given), the raw converter argument string (or
    ``None``) and the variable name.

    :raises ValueError: when a variable name appears twice, or when text left
        over after the last placeholder still contains ``<`` or ``>``

    Note: This originally lived in werkzeug.routing.parse_rule until it was
    removed in werkzeug 2.2.0.
    """
    seen_variables = set()
    cursor = 0
    rule_length = len(rule)
    while cursor < rule_length:
        match = RE_PARSE_RULE.match(rule, cursor)
        if match is None:
            break
        groups = match.groupdict()
        static_chunk = groups["static"]
        if static_chunk:
            yield None, None, static_chunk
        name = groups["variable"]
        if name in seen_variables:
            raise ValueError(f"variable name {name!r} used twice.")
        seen_variables.add(name)
        yield groups["converter"] or "default", groups["args"] or None, name
        cursor = match.end()

    if cursor < rule_length:
        tail = rule[cursor:]
        # A stray angle bracket means a placeholder the tokenizer could not match.
        if "<" in tail or ">" in tail:
            raise ValueError(f"malformed url rule: {rule!r}")
        yield None, None, tail
111
112
def extract_path_params(path):
    """
    Build Swagger path parameters from the variables of a Flask URL rule.

    :param str path: a Flask/Werkzeug URL rule such as ``/users/<int:id>``
    :returns: a mapping of variable name to Swagger ``in: path`` parameter,
        in the order the variables appear in ``path``
    :rtype: OrderedDict
    :raises ValueError: if a converter is neither a known Swagger path type
        nor registered on the current application's URL map
    """
    params = OrderedDict()
    for converter, _arguments, variable in parse_rule(path):
        if not converter:
            # Static chunk of the URL, not a variable.
            continue
        if converter in PATH_TYPES:
            swagger_type = PATH_TYPES[converter]
        elif converter in current_app.url_map.converters:
            # Custom application converters are documented as plain strings.
            swagger_type = "string"
        else:
            raise ValueError("Unsupported type converter: %s" % converter)
        params[variable] = {
            "name": variable,
            "in": "path",
            "required": True,
            "type": swagger_type,
        }
    return params
131
132
def _param_to_header(param):
    """
    Turn a Swagger parameter dict into a normalized header definition.

    The ``in`` and ``name`` keys are removed from ``param`` itself (in place)
    before the remainder is normalized by ``_clean_header``.
    """
    for parameter_only_key in ("in", "name"):
        param.pop(parameter_only_key, None)
    return _clean_header(param)
137
138
def _clean_header(header):
    """
    Normalize a header definition into a Swagger header object.

    ``header`` may be a bare description string or a dict. Its ``type``
    (``"string"`` when missing) is translated as follows:

    * a Python primitive listed in ``PY_TYPES`` -> the matching Swagger type;
    * a one-element list/tuple of such a primitive -> an ``array`` header whose
      ``items`` use the matching Swagger type;
    * an object exposing ``__schema__`` -> that schema is merged in;
    * anything else -> used verbatim as the ``type``.

    A dict argument is copied first, so the caller's documentation (for
    instance an error handler's ``__apidoc__``) is never mutated.

    :returns: the normalized header with ``None`` values dropped
    :rtype: dict
    """
    if isinstance(header, str):
        header = {"description": header}
    elif isinstance(header, dict):
        # Work on a copy: the input often belongs to user-provided decorator docs.
        header = dict(header)
    typedef = header.get("type", "string")
    # Guard the dict lookup: unhashable types (e.g. lists) cannot be PY_TYPES keys.
    if isinstance(typedef, Hashable) and typedef in PY_TYPES:
        header["type"] = PY_TYPES[typedef]
    elif (
        isinstance(typedef, (list, tuple))
        and len(typedef) == 1
        and typedef[0] in PY_TYPES
    ):
        header["type"] = "array"
        header["items"] = {"type": PY_TYPES[typedef[0]]}
    elif hasattr(typedef, "__schema__"):
        header.update(typedef.__schema__)
    else:
        header["type"] = typedef
    return not_none(header)
157
158
def parse_docstring(obj):
    """
    Split the docstring of ``obj`` into its documentation parts.

    The summary is the first line of the cleaned docstring, cut at its first
    ``.``. The details are the rest of the docstring with the leading summary
    and every ``:raises Name: description`` line removed.

    :param obj: any object whose docstring (via ``inspect.getdoc``) is parsed
    :returns: a dict with ``raw``, ``summary``, ``details``, ``returns``
        (always ``None``), ``params`` (always an empty list) and ``raises``
        (exception name -> description) keys
    :rtype: dict
    """
    raw = getdoc(obj)
    summary = raw.strip(" \n").split("\n")[0].split(".")[0] if raw else None
    raises = {}
    # Remove only the leading summary: the same words may legitimately appear
    # again further down the docstring and must be kept in the details.
    details = (
        raw.replace(summary, "", 1).lstrip(". \n").strip(" \n") if raw else None
    )
    for match in RE_RAISES.finditer(raw or ""):
        raises[match.group("name")] = match.group("description")
        if details:
            details = details.replace(match.group(0), "")
    parsed = {
        "raw": raw,
        "summary": summary or None,
        "details": details or None,
        "returns": None,
        "params": [],
        "raises": raises,
    }
    return parsed
177
178
def is_hidden(resource, route_doc=None):
    """
    Tell whether a Resource is excluded from the Swagger documentation.

    A resource is hidden when its route documentation is ``False``, or when
    the resource itself carries ``__apidoc__ = False`` (as set by the
    ``Api.doc(False)`` decorator).

    :returns: ``True`` when hidden, ``False`` otherwise
    :rtype: bool
    """
    if route_doc is False:
        return True
    return getattr(resource, "__apidoc__", None) is False
188
189
def build_request_body_parameters_schema(body_params):
    """
    Fold request parser body arguments into one Swagger body parameter.

    Every JSON body argument becomes a property of a single ``payload`` object
    schema. Only each argument's ``type`` is kept (``"string"`` when absent).

    :param body_params: JSON schemas of the body parameters, as generated from
        the JSON body arguments of a request parser
    :type body_params: list of dict
    :return: the Swagger body parameter describing the request payload
    :rtype: dict

    :Example:
        {
            'name': 'payload',
            'required': True,
            'in': 'body',
            'schema': {
                'type': 'object',
                'properties': {
                    'parameter1': {
                        'type': 'integer'
                    },
                    'parameter2': {
                        'type': 'string'
                    }
                }
            }
        }
    """
    properties = {
        body_param["name"]: {"type": body_param.get("type", "string")}
        for body_param in body_params
    }
    payload_schema = {"type": "object", "properties": properties}
    return {
        "name": "payload",
        "required": True,
        "in": "body",
        "schema": payload_schema,
    }
225
226
class Swagger(object):
    """
    A Swagger documentation wrapper for an API instance.

    Walks the namespaces, resources, models and error handlers of ``api`` and
    renders them as a Swagger 2.0 specification (see :meth:`as_dict`).
    """

    def __init__(self, api):
        self.api = api
        # Models referenced during serialization, keyed by name; they make up
        # the "definitions" section of the specification.
        self._registered_models = {}

    def as_dict(self):
        """
        Output the specification as a serializable ``dict``.

        :returns: the full Swagger specification in a serializable format
        :rtype: dict
        """
        basepath = self.api.base_path
        # Drop a trailing slash, but keep a bare "/" base path intact.
        if len(basepath) > 1 and basepath.endswith("/"):
            basepath = basepath[:-1]
        infos = {
            "title": _v(self.api.title),
            "version": _v(self.api.version),
        }
        if self.api.description:
            infos["description"] = _v(self.api.description)
        if self.api.terms_url:
            infos["termsOfService"] = _v(self.api.terms_url)
        if self.api.contact and (self.api.contact_email or self.api.contact_url):
            infos["contact"] = {
                "name": _v(self.api.contact),
                "email": _v(self.api.contact_email),
                "url": _v(self.api.contact_url),
            }
        if self.api.license:
            infos["license"] = {"name": _v(self.api.license)}
            if self.api.license_url:
                infos["license"]["url"] = _v(self.api.license_url)

        paths = {}
        tags = self.extract_tags(self.api)

        # register errors
        responses = self.register_errors()

        for ns in self.api.namespaces:
            for resource, urls, route_doc, kwargs in ns.resources:
                for url in self.api.ns_urls(ns, urls):
                    path = extract_path(url)
                    serialized = self.serialize_resource(
                        ns, resource, url, route_doc=route_doc, **kwargs
                    )
                    paths[path] = serialized

        # register all models if required
        if current_app.config["RESTX_INCLUDE_ALL_MODELS"]:
            for m in self.api.models:
                self.register_model(m)

        # merge in the top-level authorizations
        for ns in self.api.namespaces:
            if ns.authorizations:
                if self.api.authorizations is None:
                    self.api.authorizations = {}
                self.api.authorizations = merge(
                    self.api.authorizations, ns.authorizations
                )

        specs = {
            "swagger": "2.0",
            "basePath": basepath,
            "paths": not_none_sorted(paths),
            "info": infos,
            "produces": list(self.api.representations.keys()),
            "consumes": ["application/json"],
            "securityDefinitions": self.api.authorizations or None,
            "security": self.security_requirements(self.api.security) or None,
            "tags": tags,
            "definitions": self.serialize_definitions() or None,
            "responses": responses or None,
            "host": self.get_host(),
        }
        return not_none(specs)

    def get_host(self):
        """
        Return the documented host from the ``SERVER_NAME`` setting.

        The blueprint subdomain, if any, is prepended. ``None`` when no
        server name is configured.
        """
        hostname = current_app.config.get("SERVER_NAME", None) or None
        if hostname and self.api.blueprint and self.api.blueprint.subdomain:
            hostname = ".".join((self.api.blueprint.subdomain, hostname))
        return hostname

    def extract_tags(self, api):
        """
        Build the ``tags`` section from ``api.tags`` and the visible namespaces.

        Explicit tags may be a name string, a ``(name, description)`` pair or a
        dict with a ``name`` key. Namespaces that have at least one
        non-hidden resource add a tag, or update an explicit tag's description.

        :raises ValueError: on an unsupported explicit tag format
        """
        tags = []
        by_name = {}
        for tag in api.tags:
            if isinstance(tag, str):
                tag = {"name": tag}
            elif isinstance(tag, (list, tuple)):
                tag = {"name": tag[0], "description": tag[1]}
            elif isinstance(tag, dict) and "name" in tag:
                pass
            else:
                raise ValueError("Unsupported tag format for {0}".format(tag))
            tags.append(tag)
            by_name[tag["name"]] = tag
        for ns in api.namespaces:
            # hide namespaces without any Resources
            if not ns.resources:
                continue
            # hide namespaces with all Resources hidden from Swagger documentation
            if all(is_hidden(r.resource, route_doc=r.route_doc) for r in ns.resources):
                continue
            if ns.name not in by_name:
                tags.append(
                    {"name": ns.name, "description": ns.description}
                    if ns.description
                    else {"name": ns.name}
                )
            elif ns.description:
                by_name[ns.name]["description"] = ns.description
        return tags

    def extract_resource_doc(self, resource, url, route_doc=None):
        """
        Merge resource-, route- and method-level documentation for one URL.

        :returns: the merged documentation dict (with one entry per HTTP method
            and the shared ``params``), or ``False`` when the resource or route
            is hidden
        """
        route_doc = {} if route_doc is None else route_doc
        if route_doc is False:
            return False
        doc = merge(getattr(resource, "__apidoc__", {}), route_doc)
        if doc is False:
            return False

        # ensure unique names for multiple routes to the same resource
        # provides different Swagger operationId's
        doc["name"] = (
            "{}_{}".format(resource.__name__, url) if route_doc else resource.__name__
        )

        params = merge(self.expected_params(doc), doc.get("params", OrderedDict()))
        params = merge(params, extract_path_params(url))
        # Track parameters for late deduplication
        up_params = {(n, p.get("in", "query")): p for n, p in params.items()}
        need_to_go_down = set()
        methods = [m.lower() for m in resource.methods or []]
        for method in methods:
            method_doc = doc.get(method, OrderedDict())
            method_impl = getattr(resource, method)
            # Unwrap bound methods to reach the function carrying __apidoc__.
            if hasattr(method_impl, "im_func"):
                method_impl = method_impl.im_func
            elif hasattr(method_impl, "__func__"):
                method_impl = method_impl.__func__
            method_doc = merge(
                method_doc, getattr(method_impl, "__apidoc__", OrderedDict())
            )
            if method_doc is not False:
                method_doc["docstring"] = parse_docstring(method_impl)
                method_params = self.expected_params(method_doc)
                method_params = merge(method_params, method_doc.get("params", {}))
                inherited_params = OrderedDict(
                    (k, v) for k, v in params.items() if k in method_params
                )
                method_doc["params"] = merge(inherited_params, method_params)
                for name, param in method_doc["params"].items():
                    key = (name, param.get("in", "query"))
                    if key in up_params:
                        need_to_go_down.add(key)
            doc[method] = method_doc
        # Deduplicate parameters
        # For each couple (name, in), if a method overrides it,
        # we need to move the parameter down to each method
        if need_to_go_down:
            for method in methods:
                method_doc = doc.get(method)
                if not method_doc:
                    continue
                params = {
                    (n, p.get("in", "query")): p
                    for n, p in (method_doc["params"] or {}).items()
                }
                for key in need_to_go_down:
                    if key not in params:
                        method_doc["params"][key[0]] = up_params[key]
        doc["params"] = OrderedDict(
            (k[0], p) for k, p in up_params.items() if k not in need_to_go_down
        )
        return doc

    def expected_params(self, doc):
        """
        Translate the ``expect`` entries of ``doc`` into Swagger parameters.

        Request parsers contribute their non-body arguments directly and their
        body arguments as a single ``payload``. Models and list/tuple
        declarations become a ``payload`` body parameter.

        :rtype: OrderedDict
        """
        params = OrderedDict()
        if "expect" not in doc:
            return params

        for expect in doc.get("expect", []):
            if isinstance(expect, RequestParser):
                parser_params = OrderedDict(
                    (p["name"], p) for p in expect.__schema__ if p["in"] != "body"
                )
                params.update(parser_params)

                body_params = [p for p in expect.__schema__ if p["in"] == "body"]
                if body_params:
                    params["payload"] = build_request_body_parameters_schema(
                        body_params
                    )
            elif isinstance(expect, ModelBase):
                params["payload"] = not_none(
                    {
                        "name": "payload",
                        "required": True,
                        "in": "body",
                        "schema": self.serialize_schema(expect),
                    }
                )
            elif isinstance(expect, (list, tuple)):
                if len(expect) == 2:
                    # this is (payload, description) shortcut
                    model, description = expect
                    params["payload"] = not_none(
                        {
                            "name": "payload",
                            "required": True,
                            "in": "body",
                            "schema": self.serialize_schema(model),
                            "description": description,
                        }
                    )
                else:
                    params["payload"] = not_none(
                        {
                            "name": "payload",
                            "required": True,
                            "in": "body",
                            "schema": self.serialize_schema(expect),
                        }
                    )
        return params

    def register_errors(self):
        """
        Build the top-level ``responses`` section from the API error handlers.

        Each entry is keyed by the handled exception's class name. The handler
        docstring summary becomes the description, documented headers are
        added, and the model of the handler's first documented response, if
        any, becomes the schema.

        :rtype: dict
        """
        responses = {}
        for exception, handler in self.api.error_handlers.items():
            doc = parse_docstring(handler)
            response = {"description": doc["summary"]}
            apidoc = getattr(handler, "__apidoc__", {})
            self.process_headers(response, apidoc)
            if apidoc.get("responses"):
                # Only the first documented response of the handler is used.
                # Entries are normally (description, model, kwargs) tuples, but
                # a bare description string or a (description, model) pair is
                # accepted too, as in responses_for().
                first = next(iter(apidoc["responses"].values()))
                model = (
                    first[1]
                    if isinstance(first, (list, tuple)) and len(first) >= 2
                    else None
                )
                # A missing model must not turn into an invalid "void" schema.
                if model is not None:
                    response["schema"] = self.serialize_schema(model)
            responses[exception.__name__] = not_none(response)
        return responses

    def serialize_resource(self, ns, resource, url, route_doc=None, **kwargs):
        """
        Serialize one resource URL into a Swagger path item.

        :returns: the path item dict, or ``None`` when the resource is hidden
        """
        doc = self.extract_resource_doc(resource, url, route_doc=route_doc)
        if doc is False:
            return
        path = {"parameters": self.parameters_for(doc) or None}
        # Optional restriction of the documented methods for this route.
        allowed_methods = [m.lower() for m in kwargs.get("methods") or []]
        for method in [m.lower() for m in resource.methods or []]:
            if doc[method] is False or (
                allowed_methods and method not in allowed_methods
            ):
                continue
            path[method] = self.serialize_operation(doc, method)
            path[method]["tags"] = [ns.name]
        return not_none(path)

    def serialize_operation(self, doc, method):
        """Serialize the documentation of one HTTP method into an operation."""
        operation = {
            "responses": self.responses_for(doc, method) or None,
            "summary": doc[method]["docstring"]["summary"],
            "description": self.description_for(doc, method) or None,
            "operationId": self.operation_id_for(doc, method),
            "parameters": self.parameters_for(doc[method]) or None,
            "security": self.security_for(doc, method),
        }
        # Handle 'produces' mimetypes documentation
        if "produces" in doc[method]:
            operation["produces"] = doc[method]["produces"]
        # Handle deprecated annotation
        if doc.get("deprecated") or doc[method].get("deprecated"):
            operation["deprecated"] = True
        # Handle form exceptions: form parameters need a form content type.
        # Body parameters carry a "schema" instead of a "type", so keys are
        # looked up with .get() rather than indexed.
        doc_params = list(doc.get("params", {}).values())
        all_params = doc_params + (operation["parameters"] or [])
        if any(p.get("in") == "formData" for p in all_params):
            if any(p.get("type") == "file" for p in all_params):
                operation["consumes"] = ["multipart/form-data"]
            else:
                operation["consumes"] = [
                    "application/x-www-form-urlencoded",
                    "multipart/form-data",
                ]
        operation.update(self.vendor_fields(doc, method))
        return not_none(operation)

    def vendor_fields(self, doc, method):
        """
        Extract custom 3rd party Vendor fields prefixed with ``x-``

        See: https://swagger.io/specification/#specification-extensions
        """
        return dict(
            (k if k.startswith("x-") else "x-{0}".format(k), v)
            for k, v in doc[method].get("vendor", {}).items()
        )

    def description_for(self, doc, method):
        """Extract the description metadata and fallback on the whole docstring"""
        parts = []
        if "description" in doc:
            parts.append(doc["description"] or "")
        if method in doc and "description" in doc[method]:
            # An explicit None description must not break the join below.
            parts.append(doc[method]["description"] or "")
        if doc[method]["docstring"]["details"]:
            parts.append(doc[method]["docstring"]["details"])

        return "\n".join(parts).strip()

    def operation_id_for(self, doc, method):
        """Extract the operation id"""
        return (
            doc[method]["id"]
            if "id" in doc[method]
            else self.api.default_id(doc["name"], method)
        )

    def parameters_for(self, doc):
        """
        Serialize ``doc["params"]`` into a list of Swagger parameters.

        The parameter dicts are completed in place. Missing ``type`` defaults
        to ``"string"`` (unless a ``schema`` is given) and missing ``in``
        defaults to ``"query"``. Python types are mapped through ``PY_TYPES``.
        A fields-mask header parameter is appended when the doc declares a
        mask and ``RESTX_MASK_SWAGGER`` is enabled.
        """
        params = []
        for name, param in doc["params"].items():
            param["name"] = name
            if "type" not in param and "schema" not in param:
                param["type"] = "string"
            if "in" not in param:
                param["in"] = "query"

            if "type" in param and "schema" not in param:
                ptype = param.get("type", None)
                if isinstance(ptype, (list, tuple)):
                    typ = ptype[0]
                    param["type"] = "array"
                    param["items"] = {"type": PY_TYPES.get(typ, typ)}

                elif isinstance(ptype, (type, type(None))) and ptype in PY_TYPES:
                    param["type"] = PY_TYPES[ptype]

            params.append(param)

        # Handle fields mask
        mask = doc.get("__mask__")
        if mask and current_app.config["RESTX_MASK_SWAGGER"]:
            param = {
                "name": current_app.config["RESTX_MASK_HEADER"],
                "in": "header",
                "type": "string",
                "format": "mask",
                "description": "An optional fields mask",
            }
            if isinstance(mask, str):
                param["default"] = mask
            params.append(param)

        return params

    def responses_for(self, doc, method):
        """
        Build the responses of one operation.

        Sources are, at resource level then method level: the ``responses``
        declarations, a ``model`` (for the default code), and ``:raises:``
        docstring entries that match a documented error handler. Falls back to
        a single default ``200`` response.
        """
        # TODO: simplify/refactor responses/model handling
        responses = {}

        for d in doc, doc[method]:
            if "responses" in d:
                for code, response in d["responses"].items():
                    code = str(code)
                    if isinstance(response, str):
                        description = response
                        model = None
                        kwargs = {}
                    elif len(response) == 3:
                        description, model, kwargs = response
                    elif len(response) == 2:
                        description, model = response
                        kwargs = {}
                    else:
                        raise ValueError("Unsupported response specification")
                    description = description or DEFAULT_RESPONSE_DESCRIPTION
                    if code in responses:
                        responses[code].update(description=description)
                    else:
                        responses[code] = {"description": description}
                    if model:
                        schema = self.serialize_schema(model)
                        envelope = kwargs.get("envelope")
                        if envelope:
                            schema = {"properties": {envelope: schema}}
                        responses[code]["schema"] = schema
                    self.process_headers(
                        responses[code], doc, method, kwargs.get("headers")
                    )
            if "model" in d:
                code = str(d.get("default_code", HTTPStatus.OK))
                if code not in responses:
                    responses[code] = self.process_headers(
                        DEFAULT_RESPONSE.copy(), doc, method
                    )
                responses[code]["schema"] = self.serialize_schema(d["model"])

            if "docstring" in d:
                for name, description in d["docstring"]["raises"].items():
                    for exception, handler in self.api.error_handlers.items():
                        error_responses = getattr(handler, "__apidoc__", {}).get(
                            "responses", {}
                        )
                        code = (
                            str(list(error_responses.keys())[0])
                            if error_responses
                            else None
                        )
                        if code and exception.__name__ == name:
                            responses[code] = {"$ref": "#/responses/{0}".format(name)}
                            break

        if not responses:
            responses[str(HTTPStatus.OK.value)] = self.process_headers(
                DEFAULT_RESPONSE.copy(), doc, method
            )
        return responses

    def process_headers(self, response, doc, method=None, headers=None):
        """
        Attach documented headers to ``response`` (in place) and return it.

        Headers come from ``doc``, then ``doc[method]``, then ``headers``; on a
        name clash the later source wins.
        """
        method_doc = doc.get(method, {})
        if "headers" in doc or "headers" in method_doc or headers:
            response["headers"] = dict(
                (k, _clean_header(v))
                for k, v in itertools.chain(
                    doc.get("headers", {}).items(),
                    method_doc.get("headers", {}).items(),
                    (headers or {}).items(),
                )
            )
        return response

    def serialize_definitions(self):
        """Return the schemas of every registered model, keyed by name."""
        return dict(
            (name, model.__schema__) for name, model in self._registered_models.items()
        )

    def serialize_schema(self, model):
        """
        Serialize a model declaration into a Swagger schema.

        Lists/tuples become arrays of their first element. Models and model
        names are registered and referenced. Field classes and instances use
        their ``__schema__``, and Python primitives map through ``PY_TYPES``.

        :raises ValueError: for anything else
        """
        if isinstance(model, (list, tuple)):
            model = model[0]
            return {
                "type": "array",
                "items": self.serialize_schema(model),
            }

        elif isinstance(model, ModelBase):
            self.register_model(model)
            return ref(model)

        elif isinstance(model, str):
            self.register_model(model)
            return ref(model)

        elif isclass(model) and issubclass(model, fields.Raw):
            return self.serialize_schema(model())

        elif isinstance(model, fields.Raw):
            return model.__schema__

        elif isinstance(model, (type, type(None))) and model in PY_TYPES:
            return {"type": PY_TYPES[model]}

        raise ValueError("Model {0} not registered".format(model))

    def register_model(self, model):
        """
        Register ``model`` (and the models it depends on) for ``definitions``.

        :returns: a JSON reference to the model
        :raises ValueError: if the model is unknown to the API
        """
        name = model.name if isinstance(model, ModelBase) else model
        if name not in self.api.models:
            raise ValueError("Model {0} not registered".format(name))
        specs = self.api.models[name]
        if name in self._registered_models:
            return ref(model)
        self._registered_models[name] = specs
        if isinstance(specs, ModelBase):
            for parent in specs.__parents__:
                self.register_model(parent)
        if isinstance(specs, (Model, OrderedModel)):
            for field in specs.values():
                self.register_field(field)
        return ref(model)

    def register_field(self, field):
        """Register the models referenced by a (possibly container) field."""
        # Polymorph is checked before Nested so its whole mapping is registered.
        if isinstance(field, fields.Polymorph):
            for model in field.mapping.values():
                self.register_model(model)
        elif isinstance(field, fields.Nested):
            self.register_model(field.nested)
        elif isinstance(field, (fields.List, fields.Wildcard)):
            self.register_field(field.container)

    def security_for(self, doc, method):
        """Return the operation security; the method level overrides the resource."""
        security = None
        if "security" in doc:
            auth = doc["security"]
            security = self.security_requirements(auth)

        if "security" in doc[method]:
            auth = doc[method]["security"]
            security = self.security_requirements(auth)

        return security

    def security_requirements(self, value):
        """Normalize a security declaration into a list of requirements."""
        if isinstance(value, (list, tuple)):
            return [self.security_requirement(v) for v in value]
        elif value:
            requirement = self.security_requirement(value)
            return [requirement] if requirement else None
        else:
            return []

    def security_requirement(self, value):
        """
        Normalize one requirement: a scheme name or a ``{scheme: scopes}`` dict.

        A scalar scope is wrapped in a list. Other values yield ``None``.
        """
        if isinstance(value, str):
            return {value: []}
        elif isinstance(value, dict):
            return dict(
                (k, v if isinstance(v, (list, tuple)) else [v])
                for k, v in value.items()
            )
        else:
            return None