1""" 
Response validation middleware: checks outgoing responses against the API specification.
    3""" 
    4import logging 
    5import typing as t 
    6 
    7from starlette.types import ASGIApp, Receive, Scope, Send 
    8 
    9from connexion import utils 
    10from connexion.datastructures import MediaTypeDict 
    11from connexion.exceptions import NonConformingResponseHeaders 
    12from connexion.middleware.abstract import RoutedAPI, RoutedMiddleware 
    13from connexion.operations import AbstractOperation 
    14from connexion.validators import VALIDATOR_MAP 
    15 
# Module logger; used below to report responses whose content type has no
# registered body validator (validation of the body is then skipped).
logger = logging.getLogger("connexion.middleware.validation")
    17 
    18 
    19class ResponseValidationOperation: 
    20    def __init__( 
    21        self, 
    22        next_app: ASGIApp, 
    23        *, 
    24        operation: AbstractOperation, 
    25        validator_map: t.Optional[dict] = None, 
    26    ) -> None: 
    27        self.next_app = next_app 
    28        self._operation = operation 
    29        self._validator_map = VALIDATOR_MAP.copy() 
    30        self._validator_map.update(validator_map or {}) 
    31 
    32    def extract_content_type( 
    33        self, headers: t.List[t.Tuple[bytes, bytes]] 
    34    ) -> t.Tuple[str, str]: 
    35        """Extract the mime type and encoding from the content type headers. 
    36 
    37        :param headers: Headers from ASGI scope 
    38 
    39        :return: A tuple of mime type, encoding 
    40        """ 
    41        content_type = utils.extract_content_type(headers) 
    42        mime_type, encoding = utils.split_content_type(content_type) 
    43        if mime_type is None: 
    44            # Content-type header is not required. Take a best guess. 
    45            try: 
    46                mime_type = self._operation.produces[0] 
    47            except IndexError: 
    48                mime_type = "application/octet-stream" 
    49        if encoding is None: 
    50            encoding = "utf-8" 
    51 
    52        return mime_type, encoding 
    53 
    54    def validate_mime_type(self, mime_type: str) -> None: 
    55        """Validate the mime type against the spec if it defines which mime types are produced. 
    56 
    57        :param mime_type: mime type from content type header 
    58        """ 
    59        if not self._operation.produces: 
    60            return 
    61 
    62        media_type_dict = MediaTypeDict( 
    63            [(p.lower(), None) for p in self._operation.produces] 
    64        ) 
    65        if mime_type.lower() not in media_type_dict: 
    66            raise NonConformingResponseHeaders( 
    67                detail=f"Invalid Response Content-type ({mime_type}), " 
    68                f"expected {self._operation.produces}", 
    69            ) 
    70 
    71    @staticmethod 
    72    def validate_required_headers( 
    73        headers: t.List[tuple], response_definition: dict 
    74    ) -> None: 
    75        required_header_keys = { 
    76            k.lower() 
    77            for (k, v) in response_definition.get("headers", {}).items() 
    78            if v.get("required", False) 
    79        } 
    80        header_keys = set(header[0].decode("latin-1").lower() for header in headers) 
    81        missing_keys = required_header_keys - header_keys 
    82        if missing_keys: 
    83            pretty_list = ", ".join(missing_keys) 
    84            msg = ( 
    85                "Keys in response header don't match response specification. Difference: {}" 
    86            ).format(pretty_list) 
    87            raise NonConformingResponseHeaders(detail=msg) 
    88 
    89    async def __call__(self, scope: Scope, receive: Receive, send: Send): 
    90        async def wrapped_send(message: t.MutableMapping[str, t.Any]) -> None: 
    91            nonlocal send 
    92 
    93            if message["type"] == "http.response.start": 
    94                headers = message["headers"] 
    95 
    96                mime_type, encoding = self.extract_content_type(headers) 
    97                if message["status"] < 400: 
    98                    self.validate_mime_type(mime_type) 
    99 
    100                status = str(message["status"]) 
    101                response_definition = self._operation.response_definition( 
    102                    status, mime_type 
    103                ) 
    104                self.validate_required_headers(headers, response_definition) 
    105 
    106                # Validate body 
    107                try: 
    108                    body_validator = self._validator_map["response"][mime_type]  # type: ignore 
    109                except KeyError: 
    110                    logger.info( 
    111                        f"Skipping validation. No validator registered for content type: " 
    112                        f"{mime_type}." 
    113                    ) 
    114                else: 
    115                    validator = body_validator( 
    116                        scope, 
    117                        schema=self._operation.response_schema(status, mime_type), 
    118                        nullable=utils.is_nullable( 
    119                            self._operation.response_definition(status, mime_type) 
    120                        ), 
    121                        encoding=encoding, 
    122                    ) 
    123                    send = validator.wrap_send(send) 
    124 
    125            return await send(message) 
    126 
    127        await self.next_app(scope, receive, wrapped_send) 
    128 
    129 
    130class ResponseValidationAPI(RoutedAPI[ResponseValidationOperation]): 
    131    """Validation API.""" 
    132 
    133    def __init__( 
    134        self, 
    135        *args, 
    136        validator_map=None, 
    137        validate_responses=False, 
    138        **kwargs, 
    139    ): 
    140        super().__init__(*args, **kwargs) 
    141        self.validator_map = validator_map 
    142        self.validate_responses = validate_responses 
    143        self.add_paths() 
    144 
    145    def make_operation( 
    146        self, operation: AbstractOperation 
    147    ) -> ResponseValidationOperation: 
    148        if self.validate_responses: 
    149            return ResponseValidationOperation( 
    150                self.next_app, 
    151                operation=operation, 
    152                validator_map=self.validator_map, 
    153            ) 
    154        else: 
    155            return self.next_app  # type: ignore 
    156 
    157 
    158class ResponseValidationMiddleware(RoutedMiddleware[ResponseValidationAPI]): 
    159    """Middleware for validating requests according to the API contract.""" 
    160 
    161    api_cls = ResponseValidationAPI