1"""
2Validation Middleware.
3"""
4import logging
5import typing as t
6
7from starlette.types import ASGIApp, Receive, Scope, Send
8
9from connexion import utils
10from connexion.datastructures import MediaTypeDict
11from connexion.exceptions import NonConformingResponseHeaders
12from connexion.middleware.abstract import RoutedAPI, RoutedMiddleware
13from connexion.operations import AbstractOperation
14from connexion.validators import VALIDATOR_MAP
15
# Named logger for this middleware module.
logger = logging.getLogger("connexion.middleware.validation")
17
18
class ResponseValidationOperation:
    """ASGI app that validates the responses produced for a single operation.

    It wraps the ``send`` callable handed to the next app. When the
    ``http.response.start`` message is sent, the response's content type and
    required headers are checked against the operation specification. If a
    body validator is registered for the response mime type, that message and
    all later ones are sent through the validator's wrapped ``send``.
    """

    def __init__(
        self,
        next_app: ASGIApp,
        *,
        operation: AbstractOperation,
        validator_map: t.Optional[dict] = None,
    ) -> None:
        """
        :param next_app: The next ASGI app to call.
        :param operation: The operation whose responses are validated.
        :param validator_map: Optional overrides merged on top of
            ``VALIDATOR_MAP``. Top-level keys given here replace the defaults.
        """
        self.next_app = next_app
        self._operation = operation
        # Copy first so per-operation overrides never touch the shared default map.
        self._validator_map = VALIDATOR_MAP.copy()
        self._validator_map.update(validator_map or {})

    def extract_content_type(
        self, headers: t.List[t.Tuple[bytes, bytes]]
    ) -> t.Tuple[str, str]:
        """Extract the mime type and encoding from the content type headers.

        :param headers: Headers from ASGI scope

        :return: A tuple of mime type, encoding. If no mime type is present,
            the first entry of the operation's ``produces`` is used, falling
            back to ``application/octet-stream``. A missing encoding defaults
            to ``utf-8``.
        """
        content_type = utils.extract_content_type(headers)
        mime_type, encoding = utils.split_content_type(content_type)
        if mime_type is None:
            # Content-type header is not required. Take a best guess.
            try:
                mime_type = self._operation.produces[0]
            except IndexError:
                mime_type = "application/octet-stream"
        if encoding is None:
            encoding = "utf-8"

        return mime_type, encoding

    def validate_mime_type(self, mime_type: str) -> None:
        """Validate the mime type against the spec if it defines which mime types are produced.

        The comparison is case-insensitive. Nothing is checked when the
        operation declares no ``produces`` entries.

        :param mime_type: mime type from content type header
        :raises NonConformingResponseHeaders: if the mime type is not declared.
        """
        if not self._operation.produces:
            return

        media_type_dict = MediaTypeDict(
            [(p.lower(), None) for p in self._operation.produces]
        )
        if mime_type.lower() not in media_type_dict:
            raise NonConformingResponseHeaders(
                detail=f"Invalid Response Content-type ({mime_type}), "
                f"expected {self._operation.produces}",
            )

    @staticmethod
    def validate_required_headers(
        headers: t.List[tuple], response_definition: dict
    ) -> None:
        """Check that every header marked ``required`` in the spec is present.

        Header names are compared case-insensitively.

        :param headers: Raw ASGI response headers as ``(name, value)`` byte pairs.
        :param response_definition: The response object from the spec.
        :raises NonConformingResponseHeaders: if any required header is missing.
        """
        required_header_keys = {
            k.lower()
            for (k, v) in response_definition.get("headers", {}).items()
            if v.get("required", False)
        }
        header_keys = {header[0].decode("latin-1").lower() for header in headers}
        missing_keys = required_header_keys - header_keys
        if missing_keys:
            # Sort so the error message is deterministic across runs.
            pretty_list = ", ".join(sorted(missing_keys))
            msg = (
                "Keys in response header don't match response specification. "
                f"Difference: {pretty_list}"
            )
            raise NonConformingResponseHeaders(detail=msg)

    async def __call__(self, scope: Scope, receive: Receive, send: Send):
        async def wrapped_send(message: t.MutableMapping[str, t.Any]) -> None:
            # ``send`` is rebound below. This message and every later one
            # (e.g. the body) are then sent through the validator-wrapped send.
            nonlocal send

            if message["type"] == "http.response.start":
                headers = message["headers"]

                mime_type, encoding = self.extract_content_type(headers)
                # Only responses with a status below 400 have their mime
                # type checked against ``produces``.
                if message["status"] < 400:
                    self.validate_mime_type(mime_type)

                status = str(message["status"])
                response_definition = self._operation.response_definition(
                    status, mime_type
                )
                self.validate_required_headers(headers, response_definition)

                # Validate body
                try:
                    body_validator = self._validator_map["response"][mime_type]  # type: ignore
                except KeyError:
                    logger.info(
                        "Skipping validation. No validator registered for content type: %s.",
                        mime_type,
                    )
                else:
                    validator = body_validator(
                        scope,
                        schema=self._operation.response_schema(status, mime_type),
                        nullable=utils.is_nullable(response_definition),
                        encoding=encoding,
                    )
                    send = validator.wrap_send(send)

            return await send(message)

        await self.next_app(scope, receive, wrapped_send)
128
129
class ResponseValidationAPI(RoutedAPI[ResponseValidationOperation]):
    """Validation API.

    Routed API that wraps each operation in a ``ResponseValidationOperation``
    when response validation is enabled, and otherwise passes requests
    straight to the next app.
    """

    def __init__(
        self,
        *args,
        validator_map=None,
        validate_responses=False,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        # Keep these assignments ahead of ``add_paths``. It presumably builds
        # operations via ``make_operation``, which reads both attributes.
        self.validator_map = validator_map
        self.validate_responses = validate_responses
        self.add_paths()

    def make_operation(
        self, operation: AbstractOperation
    ) -> ResponseValidationOperation:
        """Return the ASGI app that should serve ``operation``.

        If response validation is disabled, the next app is returned unwrapped.
        """
        if not self.validate_responses:
            return self.next_app  # type: ignore
        return ResponseValidationOperation(
            self.next_app,
            operation=operation,
            validator_map=self.validator_map,
        )
156
157
class ResponseValidationMiddleware(RoutedMiddleware[ResponseValidationAPI]):
    """Middleware for validating responses according to the API contract."""

    # API class instantiated by the routed middleware for each registered API.
    api_cls = ResponseValidationAPI