Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/connexion/middleware/response_validation.py: 31%
71 statements
« prev ^ index » next coverage.py v7.2.2, created at 2023-03-26 06:12 +0000
1"""
2Validation Middleware.
3"""
4import logging
5import typing as t
7from starlette.types import ASGIApp, Receive, Scope, Send
9from connexion import utils
10from connexion.datastructures import MediaTypeDict
11from connexion.exceptions import NonConformingResponseHeaders
12from connexion.middleware.abstract import RoutedAPI, RoutedMiddleware
13from connexion.operations import AbstractOperation
14from connexion.validators import VALIDATOR_MAP
# Module-level logger. Note: it uses the fixed name below rather than ``__name__``.
logger = logging.getLogger(name="connexion.middleware.validation")
class ResponseValidationOperation:
    """ASGI wrapper that validates the responses of a single operation.

    On ``http.response.start`` it checks the content type, the required
    response headers and, when a body validator is registered for the mime
    type, wraps ``send`` so the response body is validated too.
    """

    def __init__(
        self,
        next_app: ASGIApp,
        *,
        operation: AbstractOperation,
        validator_map: t.Optional[dict] = None,
    ) -> None:
        """
        :param next_app: The ASGI application whose responses are validated.
        :param operation: The operation whose response specification is enforced.
        :param validator_map: Optional overrides merged over ``VALIDATOR_MAP``;
            each top-level key given here replaces the default entry.
        """
        self.next_app = next_app
        self._operation = operation
        # Copy first so merging overrides does not mutate the module-level default.
        self._validator_map = VALIDATOR_MAP.copy()
        self._validator_map.update(validator_map or {})

    def extract_content_type(
        self, headers: t.List[t.Tuple[bytes, bytes]]
    ) -> t.Tuple[str, str]:
        """Extract the mime type and encoding from the response content type header.

        Missing values are filled with best guesses: the first mime type the
        operation produces (``application/octet-stream`` if it declares none)
        and ``utf-8`` for the encoding.

        :param headers: Raw headers from the ASGI ``http.response.start`` message
        :return: A tuple of mime type, encoding
        """
        mime_type, encoding = utils.extract_content_type(headers)
        if mime_type is None:
            # Content-type header is not required. Take a best guess.
            # Guard on truthiness, like validate_mime_type, so an empty or
            # missing ``produces`` falls back instead of raising.
            produces = self._operation.produces
            mime_type = produces[0] if produces else "application/octet-stream"
        if encoding is None:
            encoding = "utf-8"
        return mime_type, encoding

    def validate_mime_type(self, mime_type: str) -> None:
        """Validate the mime type against the spec if it defines which mime types are produced.

        :param mime_type: mime type from content type header
        :raises NonConformingResponseHeaders: if the mime type is not produced
            by the operation.
        """
        if not self._operation.produces:
            return

        # Compare case-insensitively; MediaTypeDict handles media type matching.
        media_type_dict = MediaTypeDict(
            [(p.lower(), None) for p in self._operation.produces]
        )
        if mime_type.lower() not in media_type_dict:
            raise NonConformingResponseHeaders(
                detail=f"Invalid Response Content-type ({mime_type}), "
                f"expected {self._operation.produces}",
            )

    @staticmethod
    def validate_required_headers(
        headers: t.List[tuple], response_definition: dict
    ) -> None:
        """Check that every header marked ``required`` in the spec is present.

        :param headers: Raw ``(name, value)`` byte pairs from the response.
        :param response_definition: The response object from the spec.
        :raises NonConformingResponseHeaders: listing the missing header names.
        """
        required_header_keys = {
            k.lower()
            for (k, v) in response_definition.get("headers", {}).items()
            if v.get("required", False)
        }
        # ASGI header names are bytes; latin-1 decodes any byte sequence losslessly.
        header_keys = {header[0].decode("latin-1").lower() for header in headers}
        missing_keys = required_header_keys - header_keys
        if missing_keys:
            # Sort so the error message is deterministic.
            pretty_list = ", ".join(sorted(missing_keys))
            msg = (
                "Keys in response header don't match response specification. Difference: {}"
            ).format(pretty_list)
            raise NonConformingResponseHeaders(detail=msg)

    async def __call__(self, scope: Scope, receive: Receive, send: Send):
        async def wrapped_send(message: t.MutableMapping[str, t.Any]) -> None:
            # ``send`` is rebound below so later body messages pass through the
            # body validator.
            nonlocal send

            if message["type"] == "http.response.start":
                # ``headers`` is optional in this ASGI message and defaults to [].
                headers = message.get("headers", [])

                mime_type, encoding = self.extract_content_type(headers)
                if message["status"] < 400:
                    self.validate_mime_type(mime_type)

                status = str(message["status"])
                response_definition = self._operation.response_definition(
                    status, mime_type
                )
                self.validate_required_headers(headers, response_definition)

                # Validate body
                try:
                    body_validator = self._validator_map["response"][mime_type]  # type: ignore
                except KeyError:
                    logger.info(
                        "Skipping validation. No validator registered for content type: %s.",
                        mime_type,
                    )
                else:
                    validator = body_validator(
                        scope,
                        schema=self._operation.response_schema(status, mime_type),
                        nullable=utils.is_nullable(response_definition),
                        encoding=encoding,
                    )
                    send = validator.wrap_send(send)

            return await send(message)

        await self.next_app(scope, receive, wrapped_send)
class ResponseValidationAPI(RoutedAPI[ResponseValidationOperation]):
    """Routed API that wraps operations with response validation when enabled."""

    def __init__(
        self,
        *args,
        validator_map=None,
        validate_responses=False,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        # Both settings are stored before add_paths(), which presumably builds
        # operations via make_operation and so reads them.
        self.validate_responses = validate_responses
        self.validator_map = validator_map
        self.add_paths()

    def make_operation(
        self, operation: AbstractOperation
    ) -> ResponseValidationOperation:
        """Return the ASGI handler for *operation*.

        If response validation is disabled, the next app is returned as-is and
        no validation wrapper is created.
        """
        if not self.validate_responses:
            return self.next_app  # type: ignore
        return ResponseValidationOperation(
            self.next_app,
            operation=operation,
            validator_map=self.validator_map,
        )
class ResponseValidationMiddleware(RoutedMiddleware[ResponseValidationAPI]):
    """Middleware for validating responses according to the API contract."""

    # The RoutedAPI subclass this middleware uses for each API.
    api_cls = ResponseValidationAPI