1"""
2Validation Middleware.
3"""
4import logging
5import typing as t
6
7from starlette.types import ASGIApp, Receive, Scope, Send
8
9from connexion import utils
10from connexion.datastructures import MediaTypeDict
11from connexion.exceptions import UnsupportedMediaTypeProblem
12from connexion.middleware.abstract import RoutedAPI, RoutedMiddleware
13from connexion.operations import AbstractOperation
14from connexion.validators import VALIDATOR_MAP
15
# Logger used by this validation middleware; its name is spelled out
# explicitly rather than derived from __name__.
logger = logging.getLogger("connexion.middleware.validation")
17
18
class RequestValidationOperation:
    """ASGI app that validates a request for a single operation.

    On each call it:

    1. Validates parameters and headers from the ASGI scope.
    2. Checks the request's content type against the operation's
       ``consumes`` list.
    3. Wraps ``receive`` with a body validator, when the operation defines a
       body schema for that content type and a body validator is registered
       for it.

    The request is then forwarded to ``next_app``.
    """

    def __init__(
        self,
        next_app: ASGIApp,
        *,
        operation: AbstractOperation,
        strict_validation: bool = False,
        validator_map: t.Optional[dict] = None,
    ) -> None:
        """
        :param next_app: ASGI application called once validation has run.
        :param operation: Operation whose definition drives validation.
        :param strict_validation: Forwarded to the parameter and body validators.
        :param validator_map: Entries overriding the top-level keys of
            ``VALIDATOR_MAP``; ``"parameter"`` and ``"body"`` are the keys read
            by this class.
        """
        self.next_app = next_app
        self._operation = operation
        self.strict_validation = strict_validation
        # Copy first so overrides never mutate the shared module-level map.
        # The copy is shallow: an override replaces a whole top-level entry.
        self._validator_map = VALIDATOR_MAP.copy()
        self._validator_map.update(validator_map or {})

    def extract_content_type(
        self, headers: t.List[t.Tuple[bytes, bytes]]
    ) -> t.Tuple[str, str]:
        """Extract the mime type and encoding from the content type headers.

        When no mime type can be extracted, the first entry of the operation's
        ``consumes`` list is assumed, or ``"application/octet-stream"`` if that
        list is empty. A missing encoding defaults to ``"utf-8"``.

        :param headers: Headers from ASGI scope

        :return: A tuple of mime type, encoding
        """
        content_type = utils.extract_content_type(headers)
        mime_type, encoding = utils.split_content_type(content_type)
        if mime_type is None:
            # Content-type header is not required. Take a best guess.
            try:
                mime_type = self._operation.consumes[0]
            except IndexError:
                mime_type = "application/octet-stream"
        if encoding is None:
            encoding = "utf-8"

        return mime_type, encoding

    def validate_mime_type(self, mime_type: str) -> None:
        """Validate the mime type against the spec if it defines which mime types are accepted.

        Comparison is case-insensitive. It goes through ``MediaTypeDict`` so
        that media ranges listed in ``consumes`` are honoured.

        :param mime_type: mime type from content type header
        :raises UnsupportedMediaTypeProblem: if the operation declares
            ``consumes`` and ``mime_type`` matches none of its entries.
        """
        if not self._operation.consumes:
            return

        # Convert to MediaTypeDict to handle media-ranges
        media_type_dict = MediaTypeDict(
            [(c.lower(), None) for c in self._operation.consumes]
        )
        if mime_type.lower() not in media_type_dict:
            raise UnsupportedMediaTypeProblem(
                detail=f"Invalid Content-type ({mime_type}), "
                f"expected {self._operation.consumes}"
            )

    @property
    def security_query_params(self) -> t.List[str]:
        """Get the names of query parameters that are used for security.

        These are the ``name`` values of ``apiKey`` security schemes located
        ``in: query`` that are referenced by the operation's security
        requirements. The list is computed on first access and cached on
        ``self._security_query_params``.
        """
        if not hasattr(self, "_security_query_params"):
            security_query_params: t.List[str] = []
            if self._operation.security is None:
                self._security_query_params = security_query_params
                return self._security_query_params

            for security_req in self._operation.security:
                for scheme_name in security_req:
                    security_scheme = self._operation.security_schemes[scheme_name]

                    if (
                        security_scheme["type"] == "apiKey"
                        and security_scheme["in"] == "query"
                    ):
                        # Only query parameters need to be considered for strict_validation
                        security_query_params.append(security_scheme["name"])
            self._security_query_params = security_query_params

        return self._security_query_params

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Validate the request described by ``scope`` and forward it.

        Parameter/header validation and content-type validation run before
        ``next_app`` is awaited. When a body validator applies, ``receive`` is
        replaced by the channel returned from its ``wrap_receive`` before being
        passed on.
        """
        # Validate parameters & headers
        uri_parser_class = self._operation._uri_parser_class
        uri_parser = uri_parser_class(
            self._operation.parameters, self._operation.body_definition()
        )
        parameter_validator_cls = self._validator_map["parameter"]
        parameter_validator = parameter_validator_cls(  # type: ignore
            self._operation.parameters,
            uri_parser=uri_parser,
            strict_validation=self.strict_validation,
            security_query_params=self.security_query_params,
        )
        parameter_validator.validate(scope)

        # Extract content type
        headers = scope["headers"]
        mime_type, encoding = self.extract_content_type(headers)
        self.validate_mime_type(mime_type)

        # Validate body
        schema = self._operation.body_schema(mime_type)
        if schema:
            try:
                body_validator = self._validator_map["body"][mime_type]  # type: ignore
            except KeyError:
                # Log through the module logger (not the root ``logging``
                # module) so this middleware's logger configuration applies;
                # %-style args defer formatting until the record is emitted.
                logger.info(
                    "Skipping validation. No validator registered for content type: "
                    "%s.",
                    mime_type,
                )
            else:
                validator = body_validator(
                    schema=schema,
                    required=self._operation.request_body.get("required", False),
                    nullable=utils.is_nullable(
                        self._operation.body_definition(mime_type)
                    ),
                    encoding=encoding,
                    strict_validation=self.strict_validation,
                    # NOTE(review): this builds a second parser from the same
                    # arguments as ``uri_parser`` above; it could likely be
                    # reused if ``uri_parser_class`` is ``_uri_parser_class``
                    # - confirm on AbstractOperation.
                    uri_parser=self._operation.uri_parser_class(
                        self._operation.parameters, self._operation.body_definition()
                    ),
                )
                receive = await validator.wrap_receive(receive, scope=scope)

        await self.next_app(scope, receive, send)
143
144
class RequestValidationAPI(RoutedAPI[RequestValidationOperation]):
    """Routed API whose operations are ``RequestValidationOperation`` instances.

    :param strict_validation: Forwarded to every operation built by
        :meth:`make_operation`.
    :param validator_map: Validator overrides forwarded to every operation.
    :param uri_parser_class: Stored on the instance only; this class does not
        forward it to the operations it builds.
    """

    def __init__(
        self,
        *args,
        strict_validation=False,
        validator_map=None,
        uri_parser_class=None,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)

        # Store configuration before add_paths(). Building the routes is
        # presumably what calls make_operation (see RoutedAPI), which reads it.
        self.validator_map = validator_map
        self.strict_validation = strict_validation
        self.uri_parser_class = uri_parser_class
        logger.debug("Strict Request Validation: %s", str(strict_validation))

        self.add_paths()

    def make_operation(
        self, operation: AbstractOperation
    ) -> RequestValidationOperation:
        """Wrap ``operation`` in a validating ASGI app that forwards to ``next_app``."""
        shared_settings = {
            "strict_validation": self.strict_validation,
            "validator_map": self.validator_map,
        }
        return RequestValidationOperation(
            self.next_app, operation=operation, **shared_settings
        )
175
176
class RequestValidationMiddleware(RoutedMiddleware[RequestValidationAPI]):
    """Routed middleware that checks incoming requests against the API contract.

    ``api_cls`` names the API class this middleware works with. It is
    presumably instantiated by ``RoutedMiddleware`` for each registered API;
    confirm in the base class.
    """

    api_cls = RequestValidationAPI
181
182
class MissingValidationOperation(Exception):
    """Exception type signalling that a validation operation is missing."""