1from __future__ import annotations
2
3from collections.abc import AsyncGenerator
4from dataclasses import dataclass, field
5from enum import Enum
6from tempfile import SpooledTemporaryFile
7from typing import TYPE_CHECKING
8from urllib.parse import unquote_plus
9
10from starlette.datastructures import FormData, Headers, UploadFile
11
12if TYPE_CHECKING:
13 import python_multipart as multipart
14 from python_multipart.multipart import MultipartCallbacks, QuerystringCallbacks, parse_options_header
15else:
16 try:
17 try:
18 import python_multipart as multipart
19 from python_multipart.multipart import parse_options_header
20 except ModuleNotFoundError: # pragma: no cover
21 import multipart
22 from multipart.multipart import parse_options_header
23 except ModuleNotFoundError: # pragma: no cover
24 multipart = None
25 parse_options_header = None
26
27
class FormMessage(Enum):
    """Kinds of events that ``FormParser``'s callbacks queue for ``FormParser.parse()``."""

    FIELD_START = 1  # a new field begins; queued with an empty payload
    FIELD_NAME = 2  # payload is a slice of the field's (still URL-encoded) name
    FIELD_DATA = 3  # payload is a slice of the field's (still URL-encoded) value
    FIELD_END = 4  # the current field is complete; queued with an empty payload
    END = 5  # queued by ``on_end`` with an empty payload
34
35
@dataclass
class MultipartPart:
    """Mutable state that ``MultiPartParser`` collects for one part of the body."""

    # Raw value of the part's Content-Disposition header; None until that header is seen.
    content_disposition: bytes | None = None
    # Decoded "name" option taken from the Content-Disposition header.
    field_name: str = ""
    # Accumulated body bytes; only filled when the part has no ``file``.
    data: bytearray = field(default_factory=bytearray)
    # Upload target created when the header carries a "filename" option, else None.
    file: UploadFile | None = None
    # (lower-cased header name, raw header value) pairs for every header of the part.
    item_headers: list[tuple[bytes, bytes]] = field(default_factory=list)
43
44
45def _user_safe_decode(src: bytes | bytearray, codec: str) -> str:
46 try:
47 return src.decode(codec)
48 except (UnicodeDecodeError, LookupError):
49 return src.decode("latin-1")
50
51
class MultiPartException(Exception):
    """Raised when a multipart body is malformed or exceeds a configured limit.

    Attributes:
        message: Human-readable description of the problem.
    """

    def __init__(self, message: str) -> None:
        # Forward to Exception so ``args`` and ``str(exc)`` carry the message
        # even when it is supplied as a keyword argument.
        super().__init__(message)
        self.message = message
55
56
class FormParser:
    """Parse a URL-encoded form body from an async byte stream into ``FormData``.

    The ``QuerystringParser`` from python-multipart reports progress through
    synchronous callbacks. The callbacks here only queue ``FormMessage`` events
    on ``self.messages``; ``parse()`` drains that queue and assembles the
    decoded ``(name, value)`` pairs.
    """

    def __init__(self, headers: Headers, stream: AsyncGenerator[bytes, None]) -> None:
        """Store the request headers and the body stream to parse.

        ``headers`` is kept on the instance but is not consulted by ``parse()``.
        """
        assert multipart is not None, "The `python-multipart` library must be installed to use form parsing."
        self.headers = headers
        self.stream = stream
        self.messages: list[tuple[FormMessage, bytes]] = []

    def on_field_start(self) -> None:
        """Queue a marker that a new field has started."""
        self.messages.append((FormMessage.FIELD_START, b""))

    def on_field_name(self, data: bytes, start: int, end: int) -> None:
        """Queue the ``data[start:end]`` slice as a fragment of the field name."""
        self.messages.append((FormMessage.FIELD_NAME, data[start:end]))

    def on_field_data(self, data: bytes, start: int, end: int) -> None:
        """Queue the ``data[start:end]`` slice as a fragment of the field value."""
        self.messages.append((FormMessage.FIELD_DATA, data[start:end]))

    def on_field_end(self) -> None:
        """Queue a marker that the current field is complete."""
        self.messages.append((FormMessage.FIELD_END, b""))

    def on_end(self) -> None:
        """Queue a marker that the parser has been finalized."""
        self.messages.append((FormMessage.END, b""))

    def _drain_messages(
        self,
        name_buf: bytearray,
        value_buf: bytearray,
        items: list[tuple[str, str | UploadFile]],
    ) -> None:
        """Consume all queued events, appending every completed field to ``items``.

        ``name_buf`` and ``value_buf`` hold the raw bytes of the field currently
        being assembled. They are mutated in place so that a field split across
        several chunks keeps accumulating between calls.
        """
        pending = list(self.messages)
        self.messages.clear()
        for kind, payload in pending:
            if kind == FormMessage.FIELD_START:
                name_buf.clear()
                value_buf.clear()
            elif kind == FormMessage.FIELD_NAME:
                name_buf += payload
            elif kind == FormMessage.FIELD_DATA:
                value_buf += payload
            elif kind == FormMessage.FIELD_END:
                # latin-1 maps every byte to a code point, so decoding cannot
                # fail; unquote_plus then resolves "+" and percent-escapes.
                name = unquote_plus(name_buf.decode("latin-1"))
                value = unquote_plus(value_buf.decode("latin-1"))
                items.append((name, value))

    async def parse(self) -> FormData:
        """Read the whole stream and return the parsed fields as ``FormData``.

        Empty chunks carry no data and are skipped. The parser is finalized
        exactly once, after the stream is exhausted, so the trailing field is
        still emitted when the stream ends without an empty terminating chunk.
        """
        # Callbacks dictionary.
        callbacks: QuerystringCallbacks = {
            "on_field_start": self.on_field_start,
            "on_field_name": self.on_field_name,
            "on_field_data": self.on_field_data,
            "on_field_end": self.on_field_end,
            "on_end": self.on_end,
        }

        # Create the parser.
        parser = multipart.QuerystringParser(callbacks)
        name_buf = bytearray()
        value_buf = bytearray()
        items: list[tuple[str, str | UploadFile]] = []

        # Feed the parser with data from the request.
        async for chunk in self.stream:
            if not chunk:
                continue
            parser.write(chunk)
            self._drain_messages(name_buf, value_buf, items)

        # Finalizing flushes the end of the last field, if one is in progress.
        parser.finalize()
        self._drain_messages(name_buf, value_buf, items)

        return FormData(items)
123
124
class MultiPartParser:
    """Parse a multipart request body from an async byte stream into ``FormData``.

    python-multipart's ``MultipartParser`` drives the synchronous ``on_*``
    callbacks below. Plain fields are accumulated in memory; parts that carry
    a ``filename`` are backed by a ``SpooledTemporaryFile`` wrapped in an
    ``UploadFile``. File writes are queued by the callbacks and performed in
    ``parse()``, where they can be awaited.
    """

    spool_max_size = 1024 * 1024  # 1MB
    """The maximum size of the spooled temporary file used to store file data."""
    max_part_size = 1024 * 1024  # 1MB
    """The maximum size of a part in the multipart request."""

    def __init__(
        self,
        headers: Headers,
        stream: AsyncGenerator[bytes, None],
        *,
        max_files: int | float = 1000,
        max_fields: int | float = 1000,
        max_part_size: int = 1024 * 1024,  # 1MB
    ) -> None:
        """Set up parser state.

        Args:
            headers: Request headers; ``Content-Type`` is read by ``parse()``.
            stream: Async generator yielding the raw body chunks.
            max_files: Maximum number of parts with a ``filename`` option.
            max_fields: Maximum number of parts without a ``filename`` option.
            max_part_size: Maximum accumulated size, in bytes, of a non-file part.
        """
        assert multipart is not None, "The `python-multipart` library must be installed to use form parsing."
        self.headers = headers
        self.stream = stream
        self.max_files = max_files
        self.max_fields = max_fields
        self.items: list[tuple[str, str | UploadFile]] = []
        self._current_files = 0
        self._current_fields = 0
        self._current_partial_header_name: bytes = b""
        self._current_partial_header_value: bytes = b""
        self._current_part = MultipartPart()
        self._charset = ""
        # Work queued by the sync callbacks and carried out (awaited) in parse().
        self._file_parts_to_write: list[tuple[MultipartPart, bytes]] = []
        self._file_parts_to_finish: list[MultipartPart] = []
        # Every temp file created so far, so all of them can be closed on failure.
        self._files_to_close_on_error: list[SpooledTemporaryFile[bytes]] = []
        self.max_part_size = max_part_size

    def on_part_begin(self) -> None:
        """Start collecting a fresh part."""
        self._current_part = MultipartPart()

    def on_part_data(self, data: bytes, start: int, end: int) -> None:
        """Handle the ``data[start:end]`` slice of the current part's body.

        Raises:
            MultiPartException: If a non-file part would grow beyond
                ``max_part_size``.
        """
        message_bytes = data[start:end]
        part = self._current_part
        if part.file is not None:
            # File writes must be awaited, so defer them to parse().
            self._file_parts_to_write.append((part, message_bytes))
            return
        if len(part.data) + len(message_bytes) > self.max_part_size:
            raise MultiPartException(f"Part exceeded maximum size of {int(self.max_part_size / 1024)}KB.")
        part.data.extend(message_bytes)

    def on_part_end(self) -> None:
        """Record the finished part in ``self.items``."""
        part = self._current_part
        if part.file is None:
            self.items.append((part.field_name, _user_safe_decode(part.data, self._charset)))
            return
        self._file_parts_to_finish.append(part)
        # The file can be added to the items right now even though it's not
        # finished yet, because it will be finished in the `parse()` method, before
        # self.items is used in the return value.
        self.items.append((part.field_name, part.file))

    def on_header_field(self, data: bytes, start: int, end: int) -> None:
        """Accumulate a fragment of the current header's name."""
        self._current_partial_header_name += data[start:end]

    def on_header_value(self, data: bytes, start: int, end: int) -> None:
        """Accumulate a fragment of the current header's value."""
        self._current_partial_header_value += data[start:end]

    def on_header_end(self) -> None:
        """Store the completed header on the current part and reset the buffers."""
        header_name = self._current_partial_header_name.lower()
        header_value = self._current_partial_header_value
        if header_name == b"content-disposition":
            self._current_part.content_disposition = header_value
        self._current_part.item_headers.append((header_name, header_value))
        self._current_partial_header_name = b""
        self._current_partial_header_value = b""

    def on_headers_finished(self) -> None:
        """Classify the current part as file or field once its headers are complete.

        Raises:
            MultiPartException: If the Content-Disposition header has no
                ``name`` option, or the file/field count limit is exceeded.
        """
        part = self._current_part
        _, options = parse_options_header(part.content_disposition)
        if b"name" not in options:
            raise MultiPartException('The Content-Disposition header field "name" must be provided.')
        part.field_name = _user_safe_decode(options[b"name"], self._charset)

        if b"filename" in options:
            self._current_files += 1
            if self._current_files > self.max_files:
                raise MultiPartException(f"Too many files. Maximum number of files is {self.max_files}.")
            filename = _user_safe_decode(options[b"filename"], self._charset)
            spooled = SpooledTemporaryFile(max_size=self.spool_max_size)
            # Track the file before anything else can fail so it is never orphaned.
            self._files_to_close_on_error.append(spooled)
            part.file = UploadFile(
                file=spooled,  # type: ignore[arg-type]
                size=0,
                filename=filename,
                headers=Headers(raw=part.item_headers),
            )
        else:
            self._current_fields += 1
            if self._current_fields > self.max_fields:
                raise MultiPartException(f"Too many fields. Maximum number of fields is {self.max_fields}.")
            part.file = None

    def on_end(self) -> None:
        """Called when the parser is finalized; nothing to do."""
        pass

    async def parse(self) -> FormData:
        """Consume the stream and return all parsed fields and files.

        Raises:
            MultiPartException: If the ``Content-Type`` header has no boundary,
                or a callback rejects the body (see the ``on_*`` methods).

        Any exception raised while the body is being parsed, not only
        ``MultiPartException``, closes every temporary file created so far
        before it propagates.
        """
        # Parse the Content-Type header to get the multipart boundary.
        _, params = parse_options_header(self.headers["Content-Type"])
        charset = params.get(b"charset", "utf-8")
        if isinstance(charset, bytes):
            charset = charset.decode("latin-1")
        self._charset = charset
        if b"boundary" not in params:
            raise MultiPartException("Missing boundary in multipart.")
        boundary = params[b"boundary"]

        # Callbacks dictionary.
        callbacks: MultipartCallbacks = {
            "on_part_begin": self.on_part_begin,
            "on_part_data": self.on_part_data,
            "on_part_end": self.on_part_end,
            "on_header_field": self.on_header_field,
            "on_header_value": self.on_header_value,
            "on_header_end": self.on_header_end,
            "on_headers_finished": self.on_headers_finished,
            "on_end": self.on_end,
        }

        # Create the parser.
        parser = multipart.MultipartParser(boundary, callbacks)
        try:
            # Feed the parser with data from the request.
            async for chunk in self.stream:
                parser.write(chunk)
                # Write file data, it needs to use await with the UploadFile methods
                # that call the corresponding file methods *in a threadpool*,
                # otherwise, if they were called directly in the callback methods above
                # (regular, non-async functions), that would block the event loop in
                # the main thread.
                for part, data in self._file_parts_to_write:
                    assert part.file  # for type checkers
                    await part.file.write(data)
                for part in self._file_parts_to_finish:
                    assert part.file  # for type checkers
                    await part.file.seek(0)
                self._file_parts_to_write.clear()
                self._file_parts_to_finish.clear()
            parser.finalize()
        except BaseException:
            # Close all the files on any failure (parser errors, stream errors,
            # cancellation), then let the original exception propagate untouched.
            for file in self._files_to_close_on_error:
                file.close()
            raise

        return FormData(self.items)