Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/connexion/middleware/response_validation.py: 31%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

72 statements  

1""" 

2Validation Middleware. 

3""" 

4import logging 

5import typing as t 

6 

7from starlette.types import ASGIApp, Receive, Scope, Send 

8 

9from connexion import utils 

10from connexion.datastructures import MediaTypeDict 

11from connexion.exceptions import NonConformingResponseHeaders 

12from connexion.middleware.abstract import RoutedAPI, RoutedMiddleware 

13from connexion.operations import AbstractOperation 

14from connexion.validators import VALIDATOR_MAP 

15 

# Module logger; the name is given explicitly rather than derived from __name__.
logger = logging.getLogger(name="connexion.middleware.validation")

17 

18 

class ResponseValidationOperation:
    """Validate the responses produced by a single operation.

    Wraps the downstream ASGI app and inspects the ``http.response.start``
    message it sends. The content type (for status < 400) and the required
    headers are checked against the operation's spec. When a body validator
    is registered for the response mime type, ``send`` is replaced by the
    validator's ``wrap_send(send)`` (presumably so the body is validated as
    it is sent).
    """

    def __init__(
        self,
        next_app: ASGIApp,
        *,
        operation: AbstractOperation,
        validator_map: t.Optional[dict] = None,
    ) -> None:
        """
        :param next_app: The downstream ASGI application.
        :param operation: The operation whose response definitions drive
            validation.
        :param validator_map: Optional custom validators; its top-level
            entries override those of the default ``VALIDATOR_MAP``.
        """
        self.next_app = next_app
        self._operation = operation
        # Copy so that custom validators never mutate the module-level default.
        self._validator_map = VALIDATOR_MAP.copy()
        self._validator_map.update(validator_map or {})

    def extract_content_type(
        self, headers: t.List[t.Tuple[bytes, bytes]]
    ) -> t.Tuple[str, str]:
        """Extract the mime type and encoding from the content type headers.

        If no mime type is present, fall back to the first mime type the
        operation produces, or ``application/octet-stream`` if it declares
        none. If no encoding is present, fall back to ``utf-8``.

        :param headers: Raw ``(name, value)`` header pairs from the ASGI
            ``http.response.start`` message

        :return: A tuple of mime type, encoding
        """
        content_type = utils.extract_content_type(headers)
        mime_type, encoding = utils.split_content_type(content_type)
        if mime_type is None:
            # Content-type header is not required. Take a best guess.
            try:
                mime_type = self._operation.produces[0]
            except IndexError:
                mime_type = "application/octet-stream"
        if encoding is None:
            encoding = "utf-8"

        return mime_type, encoding

    def validate_mime_type(self, mime_type: str) -> None:
        """Validate the mime type against the spec if it defines which mime types are produced.

        The comparison is case-insensitive and goes through ``MediaTypeDict``
        membership. If the operation declares no ``produces``, anything is
        accepted.

        :param mime_type: mime type from content type header

        :raises NonConformingResponseHeaders: if the mime type is not one the
            operation produces.
        """
        if not self._operation.produces:
            return

        media_type_dict = MediaTypeDict(
            [(p.lower(), None) for p in self._operation.produces]
        )
        if mime_type.lower() not in media_type_dict:
            raise NonConformingResponseHeaders(
                detail=f"Invalid Response Content-type ({mime_type}), "
                f"expected {self._operation.produces}",
            )

    @staticmethod
    def validate_required_headers(
        headers: t.List[tuple], response_definition: dict
    ) -> None:
        """Check that every header marked ``required`` in the spec is present.

        Header names are compared case-insensitively.

        :param headers: Raw header pairs; each name (item 0) is bytes and is
            decoded as latin-1.
        :param response_definition: Response definition from the spec; its
            optional ``headers`` mapping is consulted.

        :raises NonConformingResponseHeaders: if any required header is missing.
        """
        required_header_keys = {
            name.lower()
            for name, definition in response_definition.get("headers", {}).items()
            if definition.get("required", False)
        }
        header_keys = {header[0].decode("latin-1").lower() for header in headers}
        missing_keys = required_header_keys - header_keys
        if missing_keys:
            # Sort so the error message is deterministic; set order is not.
            pretty_list = ", ".join(sorted(missing_keys))
            msg = (
                "Keys in response header don't match response specification. "
                f"Difference: {pretty_list}"
            )
            raise NonConformingResponseHeaders(detail=msg)

    async def __call__(self, scope: Scope, receive: Receive, send: Send):
        """Run the downstream app, validating the response messages it sends."""

        async def wrapped_send(message: t.MutableMapping[str, t.Any]) -> None:
            # ``send`` may be rebound below to the validator-wrapped version,
            # so that subsequent messages go through the body validator.
            nonlocal send

            if message["type"] == "http.response.start":
                headers = message["headers"]

                mime_type, encoding = self.extract_content_type(headers)
                # Only non-error responses have their content type checked.
                if message["status"] < 400:
                    self.validate_mime_type(mime_type)

                status = str(message["status"])
                response_definition = self._operation.response_definition(
                    status, mime_type
                )
                self.validate_required_headers(headers, response_definition)

                # Validate body
                try:
                    body_validator = self._validator_map["response"][mime_type]  # type: ignore
                except KeyError:
                    # Log through the module logger (not the root logger),
                    # with lazy %-formatting.
                    logger.info(
                        "Skipping validation. No validator registered for "
                        "content type: %s.",
                        mime_type,
                    )
                else:
                    validator = body_validator(
                        scope,
                        schema=self._operation.response_schema(status, mime_type),
                        # Reuse the definition fetched above instead of
                        # looking it up a second time.
                        nullable=utils.is_nullable(response_definition),
                        encoding=encoding,
                    )
                    send = validator.wrap_send(send)

            return await send(message)

        await self.next_app(scope, receive, wrapped_send)

58 """ 

59 if not self._operation.produces: 

60 return 

61 

62 media_type_dict = MediaTypeDict( 

63 [(p.lower(), None) for p in self._operation.produces] 

64 ) 

65 if mime_type.lower() not in media_type_dict: 

66 raise NonConformingResponseHeaders( 

67 detail=f"Invalid Response Content-type ({mime_type}), " 

68 f"expected {self._operation.produces}", 

69 ) 

70 

71 @staticmethod 

72 def validate_required_headers( 

73 headers: t.List[tuple], response_definition: dict 

74 ) -> None: 

75 required_header_keys = { 

76 k.lower() 

77 for (k, v) in response_definition.get("headers", {}).items() 

78 if v.get("required", False) 

79 } 

80 header_keys = set(header[0].decode("latin-1").lower() for header in headers) 

81 missing_keys = required_header_keys - header_keys 

82 if missing_keys: 

83 pretty_list = ", ".join(missing_keys) 

84 msg = ( 

85 "Keys in response header don't match response specification. Difference: {}" 

86 ).format(pretty_list) 

87 raise NonConformingResponseHeaders(detail=msg) 

88 

89 async def __call__(self, scope: Scope, receive: Receive, send: Send): 

90 async def wrapped_send(message: t.MutableMapping[str, t.Any]) -> None: 

91 nonlocal send 

92 

93 if message["type"] == "http.response.start": 

94 headers = message["headers"] 

95 

96 mime_type, encoding = self.extract_content_type(headers) 

97 if message["status"] < 400: 

98 self.validate_mime_type(mime_type) 

99 

100 status = str(message["status"]) 

101 response_definition = self._operation.response_definition( 

102 status, mime_type 

103 ) 

104 self.validate_required_headers(headers, response_definition) 

105 

106 # Validate body 

107 try: 

108 body_validator = self._validator_map["response"][mime_type] # type: ignore 

109 except KeyError: 

110 logging.info( 

111 f"Skipping validation. No validator registered for content type: " 

112 f"{mime_type}." 

113 ) 

114 else: 

115 validator = body_validator( 

116 scope, 

117 schema=self._operation.response_schema(status, mime_type), 

118 nullable=utils.is_nullable( 

119 self._operation.response_definition(status, mime_type) 

120 ), 

121 encoding=encoding, 

122 ) 

123 send = validator.wrap_send(send) 

124 

125 return await send(message) 

126 

127 await self.next_app(scope, receive, wrapped_send) 

128 

129 

class ResponseValidationAPI(RoutedAPI[ResponseValidationOperation]):
    """Routed API that attaches response validation to its operations.

    With ``validate_responses`` false, every operation is routed straight to
    the next app and no response validation happens.
    """

    def __init__(
        self,
        *args,
        validator_map=None,
        validate_responses=False,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        # Set before add_paths(); presumably add_paths builds operations via
        # make_operation, which reads both attributes. TODO confirm.
        self.validator_map, self.validate_responses = (
            validator_map,
            validate_responses,
        )
        self.add_paths()

    def make_operation(
        self, operation: AbstractOperation
    ) -> ResponseValidationOperation:
        """Build the handler for one operation.

        :param operation: The operation to wrap.
        :return: A ``ResponseValidationOperation`` when response validation
            is enabled; otherwise the next app itself.
        """
        if not self.validate_responses:
            # Validation disabled: hand back the next app unchanged.
            return self.next_app  # type: ignore
        return ResponseValidationOperation(
            self.next_app,
            operation=operation,
            validator_map=self.validator_map,
        )

156 

157 

class ResponseValidationMiddleware(RoutedMiddleware[ResponseValidationAPI]):
    """Middleware for validating responses according to the API contract."""

    # Routed API implementation used by this middleware.
    api_cls = ResponseValidationAPI