1from __future__ import annotations
2
3import typing
4from dataclasses import dataclass, field
5from enum import Enum
6from tempfile import SpooledTemporaryFile
7from urllib.parse import unquote_plus
8
9from starlette.datastructures import FormData, Headers, UploadFile
10
11try:
12 import multipart
13 from multipart.multipart import parse_options_header
14except ModuleNotFoundError: # pragma: nocover
15 parse_options_header = None
16 multipart = None
17
18
class FormMessage(Enum):
    """Kinds of events queued by the ``FormParser`` callbacks during parsing."""

    # A new field has begun; carries no payload.
    FIELD_START = 1
    # A fragment of the current field's (still URL-encoded) name.
    FIELD_NAME = 2
    # A fragment of the current field's (still URL-encoded) value.
    FIELD_DATA = 3
    # The current field is complete; carries no payload.
    FIELD_END = 4
    # The whole body has been consumed; carries no payload.
    END = 5
25
26
@dataclass
class MultipartPart:
    """State collected for a single part while a multipart body is parsed."""

    # Raw value of the part's Content-Disposition header; None until one is seen.
    content_disposition: bytes | None = None
    # Decoded ``name`` option taken from the Content-Disposition header.
    field_name: str = ""
    # Bytes accumulated for a part that is not a file upload.
    data: bytes = b""
    # Upload wrapper for a file part; None for a plain field.
    file: UploadFile | None = None
    # Every (lower-cased name, value) header pair seen for this part.
    # A factory is used so that instances never share one list.
    item_headers: list[tuple[bytes, bytes]] = field(default_factory=list)
34
35
def _user_safe_decode(src: bytes, codec: str) -> str:
    """Decode ``src`` with ``codec``, falling back to Latin-1 when that fails.

    The fallback is taken both when the bytes are invalid for ``codec``
    (``UnicodeDecodeError``) and when ``codec`` names no known encoding
    (``LookupError``).
    """
    try:
        text = src.decode(codec)
    except (UnicodeDecodeError, LookupError):
        # Latin-1 assigns a character to every byte value, so this always succeeds.
        text = src.decode("latin-1")
    return text
41
42
class MultiPartException(Exception):
    """Error raised when a multipart body is malformed or exceeds a configured limit.

    The human-readable description is exposed as the ``message`` attribute.
    """

    def __init__(self, message: str) -> None:
        # Keep the text on a named attribute so handlers can read it directly.
        self.message = message
46
47
class FormParser:
    """Parse a URL-encoded form body, delivered as a stream of chunks, into ``FormData``.

    The callbacks handed to ``multipart.QuerystringParser`` do no real work:
    each one only appends a ``(FormMessage, bytes)`` event to ``self.messages``.
    ``parse()`` drains that queue after every chunk and assembles the decoded
    ``(name, value)`` pairs.
    """

    def __init__(self, headers: Headers, stream: typing.AsyncGenerator[bytes, None]) -> None:
        assert multipart is not None, "The `python-multipart` library must be installed to use form parsing."
        self.stream = stream
        self.headers = headers
        # Events recorded by the callbacks and not yet consumed by parse().
        self.messages: list[tuple[FormMessage, bytes]] = []

    def on_field_start(self) -> None:
        """Queue the start of a new field."""
        self.messages.append((FormMessage.FIELD_START, b""))

    def on_field_name(self, data: bytes, start: int, end: int) -> None:
        """Queue the slice ``data[start:end]`` as a fragment of the field name."""
        self.messages.append((FormMessage.FIELD_NAME, data[start:end]))

    def on_field_data(self, data: bytes, start: int, end: int) -> None:
        """Queue the slice ``data[start:end]`` as a fragment of the field value."""
        self.messages.append((FormMessage.FIELD_DATA, data[start:end]))

    def on_field_end(self) -> None:
        """Queue the end of the current field."""
        self.messages.append((FormMessage.FIELD_END, b""))

    def on_end(self) -> None:
        """Queue the end of the whole body."""
        self.messages.append((FormMessage.END, b""))

    async def parse(self) -> FormData:
        """Consume ``self.stream`` and return the parsed fields as ``FormData``.

        A non-empty chunk is written to the parser; an empty chunk finalizes it.
        Names and values are decoded as Latin-1 and then passed through
        ``unquote_plus``.
        """
        parser = multipart.QuerystringParser(
            {
                "on_field_start": self.on_field_start,
                "on_field_name": self.on_field_name,
                "on_field_data": self.on_field_data,
                "on_field_end": self.on_field_end,
                "on_end": self.on_end,
            }
        )

        # Buffers for the field currently being assembled. They live outside
        # the chunk loop because one field may span several chunks.
        name_buf = b""
        value_buf = b""
        pairs: list[tuple[str, str | UploadFile]] = []

        async for chunk in self.stream:
            if not chunk:
                parser.finalize()
            else:
                parser.write(chunk)

            # Snapshot the queued events, then empty the shared queue in place.
            pending = self.messages[:]
            del self.messages[:]

            for kind, payload in pending:
                if kind == FormMessage.FIELD_NAME:
                    name_buf += payload
                elif kind == FormMessage.FIELD_DATA:
                    value_buf += payload
                elif kind == FormMessage.FIELD_START:
                    name_buf = b""
                    value_buf = b""
                elif kind == FormMessage.FIELD_END:
                    decoded_name = unquote_plus(name_buf.decode("latin-1"))
                    decoded_value = unquote_plus(value_buf.decode("latin-1"))
                    pairs.append((decoded_name, decoded_value))
                # FormMessage.END needs no handling here.

        return FormData(pairs)
114
115
class MultiPartParser:
    """Parse a multipart form body, delivered as a stream of chunks, into ``FormData``.

    The synchronous ``on_*`` methods are callbacks for
    ``multipart.MultipartParser``. They never touch upload files directly:
    file writes and rewinds are queued and then awaited inside ``parse()``.

    Class attributes (override on a subclass or instance to tune them):
        max_file_size: ``max_size`` passed to the ``SpooledTemporaryFile``
            that backs each uploaded file.
        max_part_size: maximum number of bytes buffered in memory for one
            non-file part; a larger part raises ``MultiPartException``.
    """

    max_file_size = 1024 * 1024  # 1 MiB
    # Non-file parts are accumulated in memory, so they must be bounded;
    # otherwise one oversized field can exhaust the process's memory.
    max_part_size = 1024 * 1024  # 1 MiB

    def __init__(
        self,
        headers: Headers,
        stream: typing.AsyncGenerator[bytes, None],
        *,
        max_files: int | float = 1000,
        max_fields: int | float = 1000,
    ) -> None:
        """Store the request headers, the body stream and the part-count limits.

        Args:
            headers: request headers; ``Content-Type`` is read by ``parse()``.
            stream: async generator yielding the raw body chunks.
            max_files: maximum number of parts that carry a ``filename``.
            max_fields: maximum number of parts without a ``filename``.
        """
        assert multipart is not None, "The `python-multipart` library must be installed to use form parsing."
        self.headers = headers
        self.stream = stream
        self.max_files = max_files
        self.max_fields = max_fields
        self.items: list[tuple[str, str | UploadFile]] = []
        self._current_files = 0
        self._current_fields = 0
        # A header name/value may arrive in several callback calls, so both
        # are accumulated until on_header_end() fires.
        self._current_partial_header_name: bytes = b""
        self._current_partial_header_value: bytes = b""
        self._current_part = MultipartPart()
        self._charset = ""
        # Work deferred to parse(), where the async UploadFile API can be awaited.
        self._file_parts_to_write: list[tuple[MultipartPart, bytes]] = []
        self._file_parts_to_finish: list[MultipartPart] = []
        self._files_to_close_on_error: list[SpooledTemporaryFile[bytes]] = []

    def on_part_begin(self) -> None:
        """Start a fresh ``MultipartPart`` for the part that is beginning."""
        self._current_part = MultipartPart()

    def on_part_data(self, data: bytes, start: int, end: int) -> None:
        """Handle the body slice ``data[start:end]`` of the current part.

        Field data is appended to the in-memory buffer; file data is queued so
        that ``parse()`` can write it through the async ``UploadFile`` API.

        Raises:
            MultiPartException: if a non-file part would grow beyond
                ``max_part_size`` bytes.
        """
        message_bytes = data[start:end]
        part = self._current_part
        if part.file is None:
            # Check before buffering so an oversized field is never held in memory.
            if len(part.data) + len(message_bytes) > self.max_part_size:
                raise MultiPartException(f"Part exceeded maximum size of {int(self.max_part_size / 1024)}KB.")
            part.data += message_bytes
        else:
            self._file_parts_to_write.append((part, message_bytes))

    def on_part_end(self) -> None:
        """Record the finished part in ``self.items``."""
        part = self._current_part
        if part.file is None:
            self.items.append((part.field_name, _user_safe_decode(part.data, self._charset)))
        else:
            self._file_parts_to_finish.append(part)
            # The file can be added to the items right now even though it's not
            # finished yet, because it will be finished in the `parse()` method, before
            # self.items is used in the return value.
            self.items.append((part.field_name, part.file))

    def on_header_field(self, data: bytes, start: int, end: int) -> None:
        """Accumulate a fragment of the current header's name."""
        self._current_partial_header_name += data[start:end]

    def on_header_value(self, data: bytes, start: int, end: int) -> None:
        """Accumulate a fragment of the current header's value."""
        self._current_partial_header_value += data[start:end]

    def on_header_end(self) -> None:
        """Store the completed header on the current part and reset the buffers."""
        # Header names are lower-cased so the comparison below is case-insensitive.
        header_name = self._current_partial_header_name.lower()
        header_value = self._current_partial_header_value
        if header_name == b"content-disposition":
            self._current_part.content_disposition = header_value
        self._current_part.item_headers.append((header_name, header_value))
        self._current_partial_header_name = b""
        self._current_partial_header_value = b""

    def on_headers_finished(self) -> None:
        """Classify the current part as a file or a field once its headers are complete.

        A part whose Content-Disposition has a ``filename`` option becomes an
        ``UploadFile`` backed by a ``SpooledTemporaryFile``; any other part is
        a plain field.

        Raises:
            MultiPartException: if the ``name`` option is missing, or if the
                ``max_files`` / ``max_fields`` limit is exceeded.
        """
        disposition, options = parse_options_header(self._current_part.content_disposition)
        try:
            self._current_part.field_name = _user_safe_decode(options[b"name"], self._charset)
        except KeyError:
            raise MultiPartException('The Content-Disposition header field "name" must be provided.')
        if b"filename" in options:
            self._current_files += 1
            if self._current_files > self.max_files:
                raise MultiPartException(f"Too many files. Maximum number of files is {self.max_files}.")
            filename = _user_safe_decode(options[b"filename"], self._charset)
            tempfile = SpooledTemporaryFile(max_size=self.max_file_size)
            # Remember the raw file so parse() can close it if parsing fails.
            self._files_to_close_on_error.append(tempfile)
            self._current_part.file = UploadFile(
                file=tempfile,  # type: ignore[arg-type]
                size=0,
                filename=filename,
                headers=Headers(raw=self._current_part.item_headers),
            )
        else:
            self._current_fields += 1
            if self._current_fields > self.max_fields:
                raise MultiPartException(f"Too many fields. Maximum number of fields is {self.max_fields}.")
            self._current_part.file = None

    def on_end(self) -> None:
        """Called when the parser reaches the end of the body; nothing to do."""
        pass

    async def parse(self) -> FormData:
        """Consume ``self.stream`` and return the parsed parts as ``FormData``.

        Raises:
            MultiPartException: if the Content-Type header has no ``boundary``
                option, or if a callback rejects the body.

        Every temporary file created so far is closed before any exception
        leaves this method.
        """
        # Parse the Content-Type header to get the multipart boundary.
        _, params = parse_options_header(self.headers["Content-Type"])
        charset = params.get(b"charset", "utf-8")
        if isinstance(charset, bytes):
            charset = charset.decode("latin-1")
        self._charset = charset
        try:
            boundary = params[b"boundary"]
        except KeyError:
            raise MultiPartException("Missing boundary in multipart.")

        # Callbacks dictionary.
        callbacks = {
            "on_part_begin": self.on_part_begin,
            "on_part_data": self.on_part_data,
            "on_part_end": self.on_part_end,
            "on_header_field": self.on_header_field,
            "on_header_value": self.on_header_value,
            "on_header_end": self.on_header_end,
            "on_headers_finished": self.on_headers_finished,
            "on_end": self.on_end,
        }

        # Create the parser.
        parser = multipart.MultipartParser(boundary, callbacks)
        try:
            # Feed the parser with data from the request.
            async for chunk in self.stream:
                parser.write(chunk)
                # Write file data, it needs to use await with the UploadFile methods
                # that call the corresponding file methods *in a threadpool*,
                # otherwise, if they were called directly in the callback methods above
                # (regular, non-async functions), that would block the event loop in
                # the main thread.
                for part, data in self._file_parts_to_write:
                    assert part.file  # for type checkers
                    await part.file.write(data)
                for part in self._file_parts_to_finish:
                    assert part.file  # for type checkers
                    await part.file.seek(0)
                self._file_parts_to_write.clear()
                self._file_parts_to_finish.clear()
            parser.finalize()
        except BaseException:
            # Close all the files on *any* failure (malformed body, parser
            # error, client disconnect, task cancellation), then let the
            # original exception propagate unchanged.
            for file in self._files_to_close_on_error:
                file.close()
            raise

        return FormData(self.items)