Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/starlette/formparsers.py: 27%
166 statements
« prev ^ index » next coverage.py v7.2.2, created at 2023-03-26 06:12 +0000
« prev ^ index » next coverage.py v7.2.2, created at 2023-03-26 06:12 +0000
1import typing
2from dataclasses import dataclass, field
3from enum import Enum
4from tempfile import SpooledTemporaryFile
5from urllib.parse import unquote_plus
7from starlette.datastructures import FormData, Headers, UploadFile
9try:
10 import multipart
11 from multipart.multipart import parse_options_header
12except ImportError: # pragma: nocover
13 parse_options_header = None
14 multipart = None
class FormMessage(Enum):
    """Kinds of events queued by the ``FormParser`` callbacks.

    Each callback appends a ``(FormMessage, bytes)`` tuple to the parser's
    message list; ``FormParser.parse`` later replays those tuples to rebuild
    the field names and values.
    """

    FIELD_START = 1  # a new field begins; name/value accumulators are reset
    FIELD_NAME = 2  # a fragment of the current field's name
    FIELD_DATA = 3  # a fragment of the current field's value
    FIELD_END = 4  # the current field is complete and can be emitted
    END = 5  # the underlying parser reported the end of the body
@dataclass
class MultipartPart:
    """State collected for a single part of a multipart body while it is parsed."""

    # Raw value of the part's Content-Disposition header, if one was seen.
    content_disposition: typing.Optional[bytes] = None
    # Decoded ``name`` option taken from the Content-Disposition header.
    field_name: str = ""
    # Accumulated payload; only filled in when the part is not a file upload.
    data: bytes = b""
    # Upload wrapper for file parts; ``None`` for plain form fields.
    file: typing.Optional[UploadFile] = None
    # Every header of the part as (lower-cased name, raw value) pairs.
    item_headers: typing.List[typing.Tuple[bytes, bytes]] = field(default_factory=list)
def _user_safe_decode(src: bytes, codec: str) -> str:
    """Decode ``src`` with ``codec``, falling back to latin-1 on failure.

    The codec name may come from the request, so both an unknown codec
    (``LookupError``) and bytes that are invalid for that codec
    (``UnicodeDecodeError``) are tolerated. latin-1 maps every byte value to
    a character, so the fallback itself cannot fail.
    """
    try:
        return src.decode(codec)
    except (UnicodeDecodeError, LookupError):
        return src.decode("latin-1")
class MultiPartException(Exception):
    """Raised when a multipart body is malformed or exceeds a configured limit.

    The human-readable reason is available as ``.message`` (and through
    ``str(exc)``).
    """

    def __init__(self, message: str) -> None:
        # Initialise the base class explicitly so ``args``/``str()`` do not
        # depend on ``BaseException.__new__`` having captured the argument.
        super().__init__(message)
        self.message = message
class FormParser:
    """Parse a URL-encoded (query-string style) request body into ``FormData``.

    The body is fed chunk by chunk to ``multipart.QuerystringParser``. That
    parser reports progress through the ``on_*`` callbacks below, which only
    queue ``(FormMessage, bytes)`` tuples; ``parse`` drains the queue after
    each chunk and assembles the decoded ``(name, value)`` pairs.
    """

    def __init__(
        self, headers: Headers, stream: typing.AsyncGenerator[bytes, None]
    ) -> None:
        """Store the request headers and the async stream of body chunks."""
        assert (
            multipart is not None
        ), "The `python-multipart` library must be installed to use form parsing."
        self.headers = headers
        self.stream = stream
        # Events produced by the callbacks and not yet consumed by `parse()`.
        self.messages: typing.List[typing.Tuple[FormMessage, bytes]] = []

    def on_field_start(self) -> None:
        """Queue a FIELD_START event."""
        message = (FormMessage.FIELD_START, b"")
        self.messages.append(message)

    def on_field_name(self, data: bytes, start: int, end: int) -> None:
        """Queue the ``data[start:end]`` fragment of the current field name."""
        message = (FormMessage.FIELD_NAME, data[start:end])
        self.messages.append(message)

    def on_field_data(self, data: bytes, start: int, end: int) -> None:
        """Queue the ``data[start:end]`` fragment of the current field value."""
        message = (FormMessage.FIELD_DATA, data[start:end])
        self.messages.append(message)

    def on_field_end(self) -> None:
        """Queue a FIELD_END event."""
        message = (FormMessage.FIELD_END, b"")
        self.messages.append(message)

    def on_end(self) -> None:
        """Queue an END event."""
        message = (FormMessage.END, b"")
        self.messages.append(message)

    async def parse(self) -> FormData:
        """Consume the stream and return the parsed fields as ``FormData``.

        Field names and values are decoded as latin-1 and then passed through
        ``unquote_plus``. A name or value may arrive split across several
        callbacks (and chunks), so fragments are accumulated until the
        matching FIELD_END event.
        """
        # Callbacks dictionary.
        callbacks = {
            "on_field_start": self.on_field_start,
            "on_field_name": self.on_field_name,
            "on_field_data": self.on_field_data,
            "on_field_end": self.on_field_end,
            "on_end": self.on_end,
        }

        # Create the parser.
        parser = multipart.QuerystringParser(callbacks)
        field_name = b""
        field_value = b""

        items: typing.List[typing.Tuple[str, typing.Union[str, UploadFile]]] = []

        # Feed the parser with data from the request.
        async for chunk in self.stream:
            if chunk:
                parser.write(chunk)
            else:
                # An empty chunk finalizes the parser so that it flushes any
                # field that is still pending.
                parser.finalize()
            # Snapshot and clear the queue before replaying it, so each
            # event is processed exactly once.
            messages = list(self.messages)
            self.messages.clear()
            for message_type, message_bytes in messages:
                if message_type == FormMessage.FIELD_START:
                    field_name = b""
                    field_value = b""
                elif message_type == FormMessage.FIELD_NAME:
                    field_name += message_bytes
                elif message_type == FormMessage.FIELD_DATA:
                    field_value += message_bytes
                elif message_type == FormMessage.FIELD_END:
                    name = unquote_plus(field_name.decode("latin-1"))
                    value = unquote_plus(field_value.decode("latin-1"))
                    items.append((name, value))

        return FormData(items)
class MultiPartParser:
    """Parse a ``multipart/form-data`` request body into ``FormData``.

    The body is fed chunk by chunk to ``multipart.MultipartParser``, which
    reports progress through the synchronous ``on_*`` callbacks below. The
    callbacks never touch upload files directly: pending writes and seeks are
    queued and then awaited from ``parse``.
    """

    # Passed as ``max_size`` to each ``SpooledTemporaryFile`` created for an
    # uploaded file.
    max_file_size = 1024 * 1024

    def __init__(
        self,
        headers: Headers,
        stream: typing.AsyncGenerator[bytes, None],
        *,
        max_files: typing.Union[int, float] = 1000,
        max_fields: typing.Union[int, float] = 1000,
    ) -> None:
        """Store the request headers, the body stream and the part limits.

        ``max_files`` / ``max_fields`` cap the number of file parts and
        non-file parts; exceeding either raises ``MultiPartException``
        while parsing.
        """
        assert (
            multipart is not None
        ), "The `python-multipart` library must be installed to use form parsing."
        self.headers = headers
        self.stream = stream
        self.max_files = max_files
        self.max_fields = max_fields
        self.items: typing.List[typing.Tuple[str, typing.Union[str, UploadFile]]] = []
        self._current_files = 0
        self._current_fields = 0
        self._current_partial_header_name: bytes = b""
        self._current_partial_header_value: bytes = b""
        self._current_part = MultipartPart()
        self._charset = ""
        self._file_parts_to_write: typing.List[typing.Tuple[MultipartPart, bytes]] = []
        self._file_parts_to_finish: typing.List[MultipartPart] = []
        self._files_to_close_on_error: typing.List[SpooledTemporaryFile] = []

    def on_part_begin(self) -> None:
        """Start collecting a fresh part."""
        self._current_part = MultipartPart()

    def on_part_data(self, data: bytes, start: int, end: int) -> None:
        """Handle the ``data[start:end]`` fragment of the current part's payload.

        Plain fields are accumulated in memory; file fragments are queued so
        that ``parse`` can write them with ``await``.
        """
        message_bytes = data[start:end]
        if self._current_part.file is None:
            self._current_part.data += message_bytes
        else:
            self._file_parts_to_write.append((self._current_part, message_bytes))

    def on_part_end(self) -> None:
        """Record the finished part in ``self.items``."""
        if self._current_part.file is None:
            self.items.append(
                (
                    self._current_part.field_name,
                    _user_safe_decode(self._current_part.data, self._charset),
                )
            )
        else:
            self._file_parts_to_finish.append(self._current_part)
            # The file can be added to the items right now even though it's not
            # finished yet, because it will be finished in the `parse()` method, before
            # self.items is used in the return value.
            self.items.append((self._current_part.field_name, self._current_part.file))

    def on_header_field(self, data: bytes, start: int, end: int) -> None:
        """Accumulate a fragment of the current header's name."""
        self._current_partial_header_name += data[start:end]

    def on_header_value(self, data: bytes, start: int, end: int) -> None:
        """Accumulate a fragment of the current header's value."""
        self._current_partial_header_value += data[start:end]

    def on_header_end(self) -> None:
        """Store the completed header on the current part and reset the buffers."""
        # Header names are compared and stored lower-cased.
        header_name = self._current_partial_header_name.lower()
        if header_name == b"content-disposition":
            self._current_part.content_disposition = self._current_partial_header_value
        self._current_part.item_headers.append(
            (header_name, self._current_partial_header_value)
        )
        self._current_partial_header_name = b""
        self._current_partial_header_value = b""

    def on_headers_finished(self) -> None:
        """Classify the current part as a file or a field once its headers are in.

        Raises:
            MultiPartException: if the Content-Disposition header carries no
                ``name`` option, or if the file / field limit is exceeded.
        """
        disposition, options = parse_options_header(
            self._current_part.content_disposition
        )
        try:
            self._current_part.field_name = _user_safe_decode(
                options[b"name"], self._charset
            )
        except KeyError:
            raise MultiPartException(
                'The Content-Disposition header field "name" must be provided.'
            )
        if b"filename" in options:
            self._current_files += 1
            if self._current_files > self.max_files:
                raise MultiPartException(
                    f"Too many files. Maximum number of files is {self.max_files}."
                )
            filename = _user_safe_decode(options[b"filename"], self._charset)
            spooled_file = SpooledTemporaryFile(max_size=self.max_file_size)
            # Remember the file so `parse()` can close it if parsing fails.
            self._files_to_close_on_error.append(spooled_file)
            self._current_part.file = UploadFile(
                file=spooled_file,  # type: ignore[arg-type]
                size=0,
                filename=filename,
                headers=Headers(raw=self._current_part.item_headers),
            )
        else:
            self._current_fields += 1
            if self._current_fields > self.max_fields:
                raise MultiPartException(
                    f"Too many fields. Maximum number of fields is {self.max_fields}."
                )
            self._current_part.file = None

    def on_end(self) -> None:
        """Nothing to do when the parser reports the end of the body."""
        pass

    async def parse(self) -> FormData:
        """Consume the stream and return all parts as ``FormData``.

        Raises:
            MultiPartException: if the Content-Type header has no boundary,
                or if a callback rejects a part. When one is raised while
                feeding the parser, every temporary file created so far is
                closed before it propagates.
        """
        # Parse the Content-Type header to get the multipart boundary.
        _, params = parse_options_header(self.headers["Content-Type"])
        charset = params.get(b"charset", "utf-8")
        # The parsed option is bytes; the "utf-8" default is already a str.
        if isinstance(charset, bytes):
            charset = charset.decode("latin-1")
        self._charset = charset
        try:
            boundary = params[b"boundary"]
        except KeyError:
            raise MultiPartException("Missing boundary in multipart.")

        # Callbacks dictionary.
        callbacks = {
            "on_part_begin": self.on_part_begin,
            "on_part_data": self.on_part_data,
            "on_part_end": self.on_part_end,
            "on_header_field": self.on_header_field,
            "on_header_value": self.on_header_value,
            "on_header_end": self.on_header_end,
            "on_headers_finished": self.on_headers_finished,
            "on_end": self.on_end,
        }

        # Create the parser.
        parser = multipart.MultipartParser(boundary, callbacks)
        try:
            # Feed the parser with data from the request.
            async for chunk in self.stream:
                parser.write(chunk)
                # Write file data, it needs to use await with the UploadFile methods
                # that call the corresponding file methods *in a threadpool*,
                # otherwise, if they were called directly in the callback methods above
                # (regular, non-async functions), that would block the event loop in
                # the main thread.
                for part, data in self._file_parts_to_write:
                    assert part.file  # for type checkers
                    await part.file.write(data)
                for part in self._file_parts_to_finish:
                    assert part.file  # for type checkers
                    await part.file.seek(0)
                self._file_parts_to_write.clear()
                self._file_parts_to_finish.clear()
        except MultiPartException:
            # Close all the files if there was an error.
            for file in self._files_to_close_on_error:
                file.close()
            raise

        parser.finalize()
        return FormData(self.items)