1import logging
2import typing as t
3
4from jsonschema import Draft4Validator, ValidationError
5from starlette.datastructures import Headers, UploadFile
6from starlette.formparsers import FormParser, MultiPartParser
7from starlette.types import Scope
8
9from connexion.exceptions import BadRequestProblem, ExtraParameterProblem
10from connexion.json_schema import Draft4RequestValidator, format_error_with_path
11from connexion.uri_parsing import AbstractURIParser
12from connexion.validators import AbstractRequestBodyValidator
13
# Module-level logger, registered under a fixed dotted name rather than ``__name__``.
logger = logging.getLogger("connexion.validators.form_data")
15
16
class FormDataValidator(AbstractRequestBodyValidator):
    """Request body validator for form content types.

    The body is parsed with the parser class returned by ``_form_parser_cls``
    and validated against the schema with a ``Draft4RequestValidator``.
    When a URI parser is configured, uploaded files are replaced by
    empty-string placeholders before validation, so file contents themselves
    are never checked against the schema.
    """

    def __init__(
        self,
        *,
        schema: dict,
        required=False,
        nullable=False,
        encoding: str,
        strict_validation: bool,
        uri_parser: t.Optional[AbstractURIParser] = None,
    ) -> None:
        """Store the URI parser and forward everything else to the base class.

        :param schema: Schema of the request body.
        :param required: Forwarded unchanged to the base class.
        :param nullable: Forwarded unchanged to the base class.
        :param encoding: Forwarded unchanged to the base class.
        :param strict_validation: Forwarded unchanged to the base class; this
            class reads it back as ``self._strict_validation`` (presumably set
            by the base-class constructor).
        :param uri_parser: Optional parser used to resolve non-file form
            fields. When ``None``, parsed fields are returned as raw lists.
        """
        super().__init__(
            schema=schema,
            required=required,
            nullable=nullable,
            encoding=encoding,
            strict_validation=strict_validation,
        )
        self._uri_parser = uri_parser

    @property
    def _validator(self):
        """Build a ``Draft4RequestValidator`` for the schema.

        A new validator instance is created on every access.
        """
        return Draft4RequestValidator(
            self._schema, format_checker=Draft4Validator.FORMAT_CHECKER
        )

    @property
    def _form_parser_cls(self):
        """Parser class used to read the request body."""
        return FormParser

    @staticmethod
    def _is_file_schema(schema: t.Any) -> bool:
        """Return whether *schema* describes a file upload.

        A file is a ``string`` schema whose ``format`` is ``binary`` or
        ``base64``. Anything that is not a mapping (for example a list used as
        ``items`` for tuple validation) is not a file schema.
        """
        return (
            isinstance(schema, dict)
            and schema.get("type") == "string"
            and schema.get("format") in ("binary", "base64")
        )

    async def _parse(self, stream: t.AsyncGenerator[bytes, None], scope: Scope) -> dict:
        """Parse the form body from *stream* into a dict.

        :param stream: Async generator yielding the raw body chunks.
        :param scope: Connection scope; the request headers are read from it.
        :return: Without a URI parser, a dict mapping each field name to the
            list of all values received for it. With a URI parser, the
            non-file fields as resolved by ``uri_parser.resolve_form`` plus
            empty-string placeholders for every field that carries files.
        """
        headers = Headers(scope=scope)
        form_parser = self._form_parser_cls(headers, stream)
        data = await form_parser.parse()

        if self._uri_parser is None:
            return {k: data.getlist(k) for k in data}

        properties = self._schema.get("properties", {})
        # Files are kept out of the data handed to the URI parser.
        form_data = {}
        file_data: t.Dict[str, t.Union[str, t.List[str]]] = {}
        for key in data.keys():
            param_schema = properties.get(key, {})
            values = data.getlist(key)

            if self._is_file_schema(param_schema):
                # Single file upload: unpack when exactly one file was received.
                # If several were received, keep an array so validation fails.
                if len(values) == 1:
                    file_data[key] = ""
                else:
                    file_data[key] = [""] * len(values)
            elif self._is_file_schema(param_schema.get("items", {})):
                # Multiple file upload: replace the files with an array of strings.
                file_data[key] = [""] * len(values)
            elif any(isinstance(value, UploadFile) for value in values):
                # UploadFile received for a non-file field. Every value is
                # inspected (not just the first) so that an upload can never
                # reach the URI parser; replace and let validation handle it.
                file_data[key] = [""] * len(values)
            else:
                # No files: pass the multi-value on and let the URI parser
                # decide how to resolve it.
                form_data[key] = values

        # Resolve form data only, then add the file placeholders back.
        resolved = self._uri_parser.resolve_form(form_data)
        resolved.update(file_data)
        return resolved

    def _validate(self, body: t.Any) -> t.Optional[dict]:  # type: ignore[return]
        """Validate the parsed *body* against the schema.

        :param body: Parsed request body; must be a dict.
        :raises BadRequestProblem: If *body* is not a dict or fails schema
            validation. The schema failure is logged first and chained as the
            cause of the raised problem.
        :raises ExtraParameterProblem: Under strict validation, if *body* has
            keys that are not listed in the schema's ``properties``.
        """
        if not isinstance(body, dict):
            raise BadRequestProblem("Parsed body must be a mapping")
        if self._strict_validation:
            self._validate_params_strictly(body)
        try:
            self._validator.validate(body)
        except ValidationError as exception:
            error_path_msg = format_error_with_path(exception=exception)
            detail = f"{exception.message}{error_path_msg}"
            logger.error(
                f"Validation error: {detail}",
                extra={"validator": "body"},
            )
            raise BadRequestProblem(detail=detail) from exception

    def _validate_params_strictly(self, data: dict) -> None:
        """Reject form fields that the schema does not declare.

        :param data: Parsed form data.
        :raises ExtraParameterProblem: If *data* has keys that are missing
            from the schema's ``properties``.
        """
        declared = self._schema.get("properties", {})
        errors = set(data) - set(declared)
        if errors:
            raise ExtraParameterProblem(param_type="formData", extra_params=errors)
117
118
class MultiPartFormDataValidator(FormDataValidator):
    """Form data validator that reads the body with ``MultiPartParser``."""

    @property
    def _form_parser_cls(self):
        """Parser class used to read the request body."""
        parser_cls = MultiPartParser
        return parser_cls