Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/connexion/validators/form_data.py: 34%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

67 statements  

1import logging 

2import typing as t 

3 

4from jsonschema import Draft4Validator, ValidationError 

5from starlette.datastructures import Headers, UploadFile 

6from starlette.formparsers import FormParser, MultiPartParser 

7from starlette.types import Scope 

8 

9from connexion.exceptions import BadRequestProblem, ExtraParameterProblem 

10from connexion.json_schema import Draft4RequestValidator, format_error_with_path 

11from connexion.uri_parsing import AbstractURIParser 

12from connexion.validators import AbstractRequestBodyValidator 

13 

# Module-level logger; FormDataValidator._validate reports schema validation errors through it.
14logger = logging.getLogger("connexion.validators.form_data") 

15 

16 

class FormDataValidator(AbstractRequestBodyValidator):
    """Request body validator for form content types.

    The request stream is parsed with the class returned by
    ``_form_parser_cls`` and the parsed fields are validated against the
    JSON schema passed to the constructor. Uploaded files are never handed to
    the URI parser or the schema validator; they are replaced by empty-string
    placeholders before validation.
    """

    def __init__(
        self,
        *,
        schema: dict,
        required=False,
        nullable=False,
        encoding: str,
        strict_validation: bool,
        uri_parser: t.Optional[AbstractURIParser] = None,
    ) -> None:
        """Initialise the validator.

        :param schema: Schema of the form body. Its ``properties`` mapping is
            consulted to decide which fields are file uploads.
        :param required: Forwarded unchanged to the base class.
        :param nullable: Forwarded unchanged to the base class.
        :param encoding: Forwarded unchanged to the base class.
        :param strict_validation: Forwarded unchanged to the base class; when
            truthy (read back via ``self._strict_validation``), form fields
            not declared in the schema are rejected.
        :param uri_parser: Optional parser used to resolve the non-file form
            fields. When ``None``, parsed fields are returned as raw lists.
        """
        super().__init__(
            schema=schema,
            required=required,
            nullable=nullable,
            encoding=encoding,
            strict_validation=strict_validation,
        )
        self._uri_parser = uri_parser

    @property
    def _validator(self):
        """Return a Draft 4 request validator for the schema.

        A new validator instance is constructed on every access.
        """
        return Draft4RequestValidator(
            self._schema, format_checker=Draft4Validator.FORMAT_CHECKER
        )

    @property
    def _form_parser_cls(self):
        """Return the parser class used to decode the request stream."""
        return FormParser

    @staticmethod
    def _is_file_schema(schema: dict) -> bool:
        """Tell whether *schema* describes a file upload.

        A file is a ``string`` typed schema whose ``format`` is ``binary`` or
        ``base64``.
        """
        return schema.get("type") == "string" and schema.get("format") in (
            "binary",
            "base64",
        )

    async def _parse(self, stream: t.AsyncGenerator[bytes, None], scope: Scope) -> dict:
        """Parse the form body from *stream*.

        :param stream: Async generator yielding the raw request body.
        :param scope: Connection scope; only used to build the request headers
            handed to the form parser.
        :return: Without a URI parser, a dict mapping every submitted key to
            the list of all its values. With a URI parser, the result of
            ``resolve_form`` on the non-file fields, updated with a
            placeholder entry for every file field.
        """
        headers = Headers(scope=scope)
        form_parser = self._form_parser_cls(headers, stream)
        data = await form_parser.parse()

        if self._uri_parser is None:
            return {k: data.getlist(k) for k in data}

        properties = self._schema.get("properties", {})
        # Files are kept apart so that only plain form values reach the URI parser.
        form_data = {}
        file_data: t.Dict[str, t.Union[str, t.List[str]]] = {}
        for key in data:
            param_schema = properties.get(key, {})
            value = data.getlist(key)

            if self._is_file_schema(param_schema):
                # Single file upload: unpack if exactly one file was received,
                # otherwise replace with an array so validation will fail.
                if len(value) == 1:
                    file_data[key] = ""
                else:
                    file_data[key] = [""] * len(value)
            elif self._is_file_schema(param_schema.get("items", {})):
                # Multiple file upload: replace files with an array of strings.
                file_data[key] = [""] * len(value)
            elif any(isinstance(item, UploadFile) for item in value):
                # UploadFile received for a non-file field. Every value is
                # inspected (not just the first) so that an upload mixed in
                # after a text value is never passed on to the URI parser.
                # Replace and let validation handle it.
                file_data[key] = [""] * len(value)
            else:
                # No files: keep the multi-value list and let the URI parser
                # resolve it.
                form_data[key] = value

        # Resolve form data, not file data, then add the files again.
        resolved = self._uri_parser.resolve_form(form_data)
        resolved.update(file_data)
        return resolved

    def _validate(self, body: t.Any) -> t.Optional[dict]:  # type: ignore[return]
        """Validate the parsed *body* against the schema.

        :param body: Parsed request body; must be a ``dict``.
        :raises BadRequestProblem: If *body* is not a dict, or if schema
            validation fails (the failure is logged first).
        :raises ExtraParameterProblem: Under strict validation, if *body*
            contains keys not declared in the schema's ``properties``.
        """
        if not isinstance(body, dict):
            raise BadRequestProblem("Parsed body must be a mapping")
        if self._strict_validation:
            self._validate_params_strictly(body)
        try:
            self._validator.validate(body)
        except ValidationError as exception:
            error_path_msg = format_error_with_path(exception=exception)
            logger.error(
                f"Validation error: {exception.message}{error_path_msg}",
                extra={"validator": "body"},
            )
            raise BadRequestProblem(
                detail=f"{exception.message}{error_path_msg}"
            ) from exception

    def _validate_params_strictly(self, data: dict) -> None:
        """Reject form fields that the schema does not declare.

        :param data: Parsed form body.
        :raises ExtraParameterProblem: If any key of *data* is missing from
            the schema's ``properties``.
        """
        form_params = data.keys()
        spec_params = self._schema.get("properties", {}).keys()
        errors = set(form_params).difference(set(spec_params))
        if errors:
            raise ExtraParameterProblem(param_type="formData", extra_params=errors)

117 

118 

class MultiPartFormDataValidator(FormDataValidator):
    """Form data validator that decodes the body with ``MultiPartParser``.

    Everything except the parser class is inherited from
    ``FormDataValidator``.
    """

    @property
    def _form_parser_cls(self):
        """Return the parser class used to decode the request stream."""
        return MultiPartParser