1from __future__ import annotations
2
3import typing
4from dataclasses import dataclass, field
5from enum import Enum
6from tempfile import SpooledTemporaryFile
7from urllib.parse import unquote_plus
8
9from starlette.datastructures import FormData, Headers, UploadFile
10
11try:
12 import multipart
13 from multipart.multipart import parse_options_header
14except ModuleNotFoundError: # pragma: nocover
15 parse_options_header = None
16 multipart = None
17
18
class FormMessage(Enum):
    """Kinds of events that `FormParser` callbacks queue for `FormParser.parse()`.

    Each queued message is a ``(FormMessage, bytes)`` pair; the bytes payload
    is only meaningful for ``FIELD_NAME`` and ``FIELD_DATA``.
    """

    # A new field begins (queued with an empty payload).
    FIELD_START = 1
    # A fragment of the current field's raw name.
    FIELD_NAME = 2
    # A fragment of the current field's raw value.
    FIELD_DATA = 3
    # The current field is complete (queued with an empty payload).
    FIELD_END = 4
    # The parser reported the end of the input (queued with an empty payload).
    END = 5
25
26
@dataclass
class MultipartPart:
    """Mutable record for one part of a multipart body while it is being parsed."""

    # Raw value of the part's Content-Disposition header, if one has been seen.
    content_disposition: bytes | None = None
    # Decoded form field name for this part.
    field_name: str = ""
    # Payload bytes accumulated for parts that are not backed by a file.
    data: bytes = b""
    # Upload object for file parts; stays None for plain fields.
    file: UploadFile | None = None
    # (name, value) header pairs collected for this part, in arrival order.
    # A factory is used so that every instance gets its own list.
    item_headers: list[tuple[bytes, bytes]] = field(default_factory=list)
34
35
def _user_safe_decode(src: bytes, codec: str) -> str:
    """Decode ``src`` using ``codec``, falling back to latin-1 on failure.

    The fallback applies both when ``src`` is not valid in ``codec`` and when
    ``codec`` is not a known encoding name (including the empty string).
    """
    try:
        text = src.decode(codec)
    except (UnicodeDecodeError, LookupError):
        # latin-1 maps every byte value to a code point, so this cannot fail.
        text = src.decode("latin-1")
    return text
41
42
class MultiPartException(Exception):
    """Error raised when a multipart body cannot be accepted.

    Attributes:
        message: Human-readable description of the problem.
    """

    def __init__(self, message: str) -> None:
        # Forward the message to ``Exception`` so that ``args`` and ``str()``
        # are populated even when ``message`` is passed as a keyword argument.
        super().__init__(message)
        self.message = message
46
47
class FormParser:
    """Parse a URL-encoded form body from an async byte stream.

    The `python-multipart` ``QuerystringParser`` drives the ``on_*`` callbacks
    below; each callback only queues a ``(FormMessage, bytes)`` pair, and
    `parse()` turns the queued messages into decoded ``(name, value)`` items.
    """

    def __init__(
        self, headers: Headers, stream: typing.AsyncGenerator[bytes, None]
    ) -> None:
        assert (
            multipart is not None
        ), "The `python-multipart` library must be installed to use form parsing."
        self.headers = headers
        self.stream = stream
        # Events queued by the callbacks until `parse()` drains them.
        self.messages: list[tuple[FormMessage, bytes]] = []

    def on_field_start(self) -> None:
        """Queue a marker that a new field has started."""
        self.messages.append((FormMessage.FIELD_START, b""))

    def on_field_name(self, data: bytes, start: int, end: int) -> None:
        """Queue the ``data[start:end]`` fragment of the current field name."""
        self.messages.append((FormMessage.FIELD_NAME, data[start:end]))

    def on_field_data(self, data: bytes, start: int, end: int) -> None:
        """Queue the ``data[start:end]`` fragment of the current field value."""
        self.messages.append((FormMessage.FIELD_DATA, data[start:end]))

    def on_field_end(self) -> None:
        """Queue a marker that the current field is complete."""
        self.messages.append((FormMessage.FIELD_END, b""))

    def on_end(self) -> None:
        """Queue a marker that the parser reached the end of the input."""
        self.messages.append((FormMessage.END, b""))

    def _drain_messages(
        self,
        items: list[tuple[str, str | UploadFile]],
        field_name: bytes,
        field_value: bytes,
    ) -> tuple[bytes, bytes]:
        """Apply all queued messages, appending completed fields to ``items``.

        ``field_name`` and ``field_value`` are the partial buffers of a field
        that may span several chunks; the updated buffers are returned so the
        caller can pass them back in on the next call.
        """
        pending = list(self.messages)
        self.messages.clear()
        for message_type, message_bytes in pending:
            if message_type == FormMessage.FIELD_START:
                field_name = b""
                field_value = b""
            elif message_type == FormMessage.FIELD_NAME:
                field_name += message_bytes
            elif message_type == FormMessage.FIELD_DATA:
                field_value += message_bytes
            elif message_type == FormMessage.FIELD_END:
                # latin-1 decoding never fails; percent-escapes and "+" are
                # then resolved by unquote_plus.
                name = unquote_plus(field_name.decode("latin-1"))
                value = unquote_plus(field_value.decode("latin-1"))
                items.append((name, value))
        return field_name, field_value

    async def parse(self) -> FormData:
        """Consume the stream and return the decoded form fields.

        An empty chunk finalizes the parser. If the stream ends without ever
        yielding an empty chunk, the parser is finalized once after the loop
        so that whatever it emits on finalization is not silently dropped.

        Returns:
            A ``FormData`` built from the ``(name, value)`` pairs in the order
            the fields were completed.
        """
        # Callbacks dictionary.
        callbacks = {
            "on_field_start": self.on_field_start,
            "on_field_name": self.on_field_name,
            "on_field_data": self.on_field_data,
            "on_field_end": self.on_field_end,
            "on_end": self.on_end,
        }

        # Create the parser.
        parser = multipart.QuerystringParser(callbacks)
        field_name = b""
        field_value = b""
        finalized = False

        items: list[tuple[str, str | UploadFile]] = []

        # Feed the parser with data from the request.
        async for chunk in self.stream:
            if chunk:
                parser.write(chunk)
            else:
                parser.finalize()
                finalized = True
            field_name, field_value = self._drain_messages(
                items, field_name, field_value
            )

        # The stream stopped without the empty end-of-body chunk: finalize now
        # and process the messages that produces.
        if not finalized:
            parser.finalize()
            self._drain_messages(items, field_name, field_value)

        return FormData(items)
118
119
class MultiPartParser:
    """Parse a multipart request body from an async byte stream.

    The `python-multipart` ``MultipartParser`` drives the synchronous ``on_*``
    callbacks below. The callbacks never touch files directly: they queue the
    file work, and `parse()` performs it with ``await`` after each chunk.
    """

    # Passed to ``SpooledTemporaryFile(max_size=...)`` for every uploaded file.
    max_file_size = 1024 * 1024

    def __init__(
        self,
        headers: Headers,
        stream: typing.AsyncGenerator[bytes, None],
        *,
        max_files: int | float = 1000,
        max_fields: int | float = 1000,
    ) -> None:
        """Store the request headers, the body stream and the part limits.

        Args:
            headers: Request headers; ``Content-Type`` is read by `parse()`.
            stream: Async stream of body chunks.
            max_files: Maximum number of parts carrying a ``filename``.
            max_fields: Maximum number of parts without a ``filename``.
        """
        assert (
            multipart is not None
        ), "The `python-multipart` library must be installed to use form parsing."
        self.headers = headers
        self.stream = stream
        self.max_files = max_files
        self.max_fields = max_fields
        self.items: list[tuple[str, str | UploadFile]] = []
        self._current_files = 0
        self._current_fields = 0
        # Header name/value may arrive in several fragments; accumulate them.
        self._current_partial_header_name: bytes = b""
        self._current_partial_header_value: bytes = b""
        self._current_part = MultipartPart()
        self._charset = ""
        # File work queued by the callbacks and performed by `parse()`.
        self._file_parts_to_write: list[tuple[MultipartPart, bytes]] = []
        self._file_parts_to_finish: list[MultipartPart] = []
        # Every temp file created so far, so they can be closed on failure.
        self._files_to_close_on_error: list[SpooledTemporaryFile[bytes]] = []

    def on_part_begin(self) -> None:
        """Start a fresh record for the part that is beginning."""
        self._current_part = MultipartPart()

    def on_part_data(self, data: bytes, start: int, end: int) -> None:
        """Handle the ``data[start:end]`` fragment of the current part's body.

        Plain fields accumulate the bytes in memory; file parts queue the
        bytes to be written by `parse()`.
        """
        message_bytes = data[start:end]
        if self._current_part.file is None:
            self._current_part.data += message_bytes
        else:
            self._file_parts_to_write.append((self._current_part, message_bytes))

    def on_part_end(self) -> None:
        """Record the finished part in ``self.items``."""
        if self._current_part.file is None:
            self.items.append(
                (
                    self._current_part.field_name,
                    _user_safe_decode(self._current_part.data, self._charset),
                )
            )
        else:
            self._file_parts_to_finish.append(self._current_part)
            # The file can be added to the items right now even though it's not
            # finished yet, because it will be finished in the `parse()` method,
            # before self.items is used in the return value.
            self.items.append((self._current_part.field_name, self._current_part.file))

    def on_header_field(self, data: bytes, start: int, end: int) -> None:
        """Append the ``data[start:end]`` fragment to the pending header name."""
        self._current_partial_header_name += data[start:end]

    def on_header_value(self, data: bytes, start: int, end: int) -> None:
        """Append the ``data[start:end]`` fragment to the pending header value."""
        self._current_partial_header_value += data[start:end]

    def on_header_end(self) -> None:
        """Store the completed header on the current part and reset the buffers."""
        field = self._current_partial_header_name.lower()
        if field == b"content-disposition":
            self._current_part.content_disposition = self._current_partial_header_value
        self._current_part.item_headers.append(
            (field, self._current_partial_header_value)
        )
        self._current_partial_header_name = b""
        self._current_partial_header_value = b""

    def on_headers_finished(self) -> None:
        """Classify the current part as a file or a field once its headers are in.

        Raises:
            MultiPartException: If Content-Disposition has no ``name`` option,
                or the file or field limit is exceeded.
        """
        _, options = parse_options_header(self._current_part.content_disposition)
        try:
            self._current_part.field_name = _user_safe_decode(
                options[b"name"], self._charset
            )
        except KeyError:
            raise MultiPartException(
                'The Content-Disposition header field "name" must be provided.'
            )
        if b"filename" in options:
            self._current_files += 1
            if self._current_files > self.max_files:
                raise MultiPartException(
                    f"Too many files. Maximum number of files is {self.max_files}."
                )
            filename = _user_safe_decode(options[b"filename"], self._charset)
            tempfile = SpooledTemporaryFile(max_size=self.max_file_size)
            # Track the file immediately so a later failure can close it.
            self._files_to_close_on_error.append(tempfile)
            self._current_part.file = UploadFile(
                file=tempfile,  # type: ignore[arg-type]
                size=0,
                filename=filename,
                headers=Headers(raw=self._current_part.item_headers),
            )
        else:
            self._current_fields += 1
            if self._current_fields > self.max_fields:
                raise MultiPartException(
                    f"Too many fields. Maximum number of fields is {self.max_fields}."
                )
            self._current_part.file = None

    def on_end(self) -> None:
        """Nothing to do when the parser reports the end of the body."""
        pass

    async def parse(self) -> FormData:
        """Consume the stream and return the parsed fields and files.

        Returns:
            A ``FormData`` built from ``self.items``.

        Raises:
            MultiPartException: If the Content-Type header has no boundary, or
                a callback rejects the body. Any exception raised while the
                body is being consumed is re-raised unchanged after all temp
                files created so far have been closed.
        """
        # Parse the Content-Type header to get the multipart boundary.
        _, params = parse_options_header(self.headers["Content-Type"])
        charset = params.get(b"charset", "utf-8")
        if isinstance(charset, bytes):
            charset = charset.decode("latin-1")
        self._charset = charset
        try:
            boundary = params[b"boundary"]
        except KeyError:
            raise MultiPartException("Missing boundary in multipart.")

        # Callbacks dictionary.
        callbacks = {
            "on_part_begin": self.on_part_begin,
            "on_part_data": self.on_part_data,
            "on_part_end": self.on_part_end,
            "on_header_field": self.on_header_field,
            "on_header_value": self.on_header_value,
            "on_header_end": self.on_header_end,
            "on_headers_finished": self.on_headers_finished,
            "on_end": self.on_end,
        }

        # Create the parser.
        parser = multipart.MultipartParser(boundary, callbacks)
        try:
            # Feed the parser with data from the request.
            async for chunk in self.stream:
                parser.write(chunk)
                # Write file data, it needs to use await with the UploadFile
                # methods that call the corresponding file methods *in a
                # threadpool*, otherwise, if they were called directly in the
                # callback methods above (regular, non-async functions), that
                # would block the event loop in the main thread.
                for part, data in self._file_parts_to_write:
                    assert part.file  # for type checkers
                    await part.file.write(data)
                for part in self._file_parts_to_finish:
                    assert part.file  # for type checkers
                    await part.file.seek(0)
                self._file_parts_to_write.clear()
                self._file_parts_to_finish.clear()
            parser.finalize()
        except BaseException:
            # Close all the files on *any* failure (parser errors, stream
            # errors, cancellation), not only MultiPartException, so the temp
            # files are not leaked. The bare `raise` keeps the original
            # exception and traceback intact.
            for file in self._files_to_close_on_error:
                file.close()
            raise

        return FormData(self.items)