Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/starlette/formparsers.py: 29%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

178 statements  

1from __future__ import annotations 

2 

3from collections.abc import AsyncGenerator 

4from dataclasses import dataclass, field 

5from enum import Enum 

6from tempfile import SpooledTemporaryFile 

7from typing import TYPE_CHECKING 

8from urllib.parse import unquote_plus 

9 

10from starlette.datastructures import FormData, Headers, UploadFile 

11 

12if TYPE_CHECKING: 

13 import python_multipart as multipart 

14 from python_multipart.multipart import MultipartCallbacks, QuerystringCallbacks, parse_options_header 

15else: 

16 try: 

17 try: 

18 import python_multipart as multipart 

19 from python_multipart.multipart import parse_options_header 

20 except ModuleNotFoundError: # pragma: no cover 

21 import multipart 

22 from multipart.multipart import parse_options_header 

23 except ModuleNotFoundError: # pragma: no cover 

24 multipart = None 

25 parse_options_header = None 

26 

27 

class FormMessage(Enum):
    """Event kinds that FormParser's callbacks queue while parsing a urlencoded body."""

    FIELD_START = 1  # a new name=value pair begins
    FIELD_NAME = 2  # payload is a fragment of the field name
    FIELD_DATA = 3  # payload is a fragment of the field value
    FIELD_END = 4  # the current name=value pair is complete
    END = 5  # the parser reached the end of the body

34 

35 

@dataclass
class MultipartPart:
    """Mutable state for the multipart part currently being parsed.

    Plain form fields collect their payload in ``data``; parts whose
    Content-Disposition carries a ``filename`` get an ``UploadFile`` in
    ``file`` instead.
    """

    # Raw Content-Disposition header value, or None if the part had none.
    content_disposition: bytes | None = None
    # Decoded ``name`` parameter from Content-Disposition.
    field_name: str = ""
    # In-memory body of a non-file part.
    data: bytearray = field(default_factory=bytearray)
    # Upload target for file parts; None for plain fields.
    file: UploadFile | None = None
    # All headers of the part as (lowercased name, value) pairs.
    item_headers: list[tuple[bytes, bytes]] = field(default_factory=list)

43 

44 

45def _user_safe_decode(src: bytes | bytearray, codec: str) -> str: 

46 try: 

47 return src.decode(codec) 

48 except (UnicodeDecodeError, LookupError): 

49 return src.decode("latin-1") 

50 

51 

class MultiPartException(Exception):
    """Raised when a form body is malformed or exceeds a configured limit.

    The human-readable reason is stored on ``message``; it is also the
    exception's single positional argument, so ``str(exc)`` and
    ``exc.args[0]`` return it too.
    """

    def __init__(self, message: str) -> None:
        # Forward to Exception.__init__ so ``args`` is populated even when the
        # exception is constructed as ``MultiPartException(message=...)``.
        # Without this, ``str(exc)`` is empty and the instance cannot be
        # unpickled (unpickling re-calls the class with ``args``).
        super().__init__(message)
        self.message = message

55 

56 

class FormParser:
    """Parse an ``application/x-www-form-urlencoded`` body into FormData.

    python-multipart's ``QuerystringParser`` does the tokenizing. Its
    callbacks (the ``on_*`` methods) only queue ``(FormMessage, bytes)``
    events on ``self.messages``; ``parse()`` drains that queue after every
    chunk and assembles ``(name, value)`` pairs from it.
    """

    def __init__(self, headers: Headers, stream: AsyncGenerator[bytes, None]) -> None:
        assert multipart is not None, "The `python-multipart` library must be installed to use form parsing."
        self.headers = headers
        self.stream = stream
        # Events queued by the parser callbacks; drained by parse() after each chunk.
        self.messages: list[tuple[FormMessage, bytes]] = []

    def on_field_start(self) -> None:
        """Record that a new field begins."""
        message = (FormMessage.FIELD_START, b"")
        self.messages.append(message)

    def on_field_name(self, data: bytes, start: int, end: int) -> None:
        """Record the fragment ``data[start:end]`` of the current field's name."""
        message = (FormMessage.FIELD_NAME, data[start:end])
        self.messages.append(message)

    def on_field_data(self, data: bytes, start: int, end: int) -> None:
        """Record the fragment ``data[start:end]`` of the current field's value."""
        message = (FormMessage.FIELD_DATA, data[start:end])
        self.messages.append(message)

    def on_field_end(self) -> None:
        """Record that the current field is complete."""
        message = (FormMessage.FIELD_END, b"")
        self.messages.append(message)

    def on_end(self) -> None:
        """Record that the parser reached the end of the body."""
        message = (FormMessage.END, b"")
        self.messages.append(message)

    def _drain_messages(
        self,
        field_name: bytearray,
        field_value: bytearray,
        items: list[tuple[str, str | UploadFile]],
    ) -> None:
        """Fold the queued parser events into ``items`` and empty the queue.

        ``field_name`` and ``field_value`` accumulate the field in progress and
        are mutated in place, so a field split across several chunks is
        reassembled across successive calls.
        """
        messages = list(self.messages)
        self.messages.clear()
        for message_type, message_bytes in messages:
            if message_type == FormMessage.FIELD_START:
                field_name.clear()
                field_value.clear()
            elif message_type == FormMessage.FIELD_NAME:
                field_name.extend(message_bytes)
            elif message_type == FormMessage.FIELD_DATA:
                field_value.extend(message_bytes)
            elif message_type == FormMessage.FIELD_END:
                # latin-1 maps each raw byte to one code point; unquote_plus then
                # resolves ``+`` and ``%XX`` escapes.
                name = unquote_plus(field_name.decode("latin-1"))
                value = unquote_plus(field_value.decode("latin-1"))
                items.append((name, value))

    async def parse(self) -> FormData:
        """Consume ``self.stream`` and return the decoded fields as FormData.

        An empty chunk marks the end of the body and finalizes the parser. The
        parser is finalized at most once. If the stream is exhausted without
        yielding an empty chunk, the parser is finalized after the loop, so
        events it only emits on finalize (presumably the end of the last field)
        are still collected.
        """
        callbacks: QuerystringCallbacks = {
            "on_field_start": self.on_field_start,
            "on_field_name": self.on_field_name,
            "on_field_data": self.on_field_data,
            "on_field_end": self.on_field_end,
            "on_end": self.on_end,
        }

        parser = multipart.QuerystringParser(callbacks)
        # bytearray accumulators make appending fragments amortized O(1);
        # ``bytes += bytes`` would copy the whole value on every fragment.
        field_name = bytearray()
        field_value = bytearray()
        items: list[tuple[str, str | UploadFile]] = []
        finalized = False

        # Feed the parser with data from the request.
        async for chunk in self.stream:
            if chunk:
                parser.write(chunk)
            elif not finalized:
                # Finalize only once; a repeated finalize could emit duplicate
                # end-of-field events. NOTE(review): assumption about
                # python-multipart's QuerystringParser, confirm.
                parser.finalize()
                finalized = True
            self._drain_messages(field_name, field_value, items)

        if not finalized:
            # The stream ended without the empty end-of-body chunk.
            parser.finalize()
            self._drain_messages(field_name, field_value, items)

        return FormData(items)

123 

124 

class MultiPartParser:
    """Parse a ``multipart/form-data`` body into FormData.

    python-multipart's ``MultipartParser`` drives the synchronous ``on_*``
    callbacks below.
    - Plain fields are buffered in memory, bounded by ``max_part_size``.
    - File parts get an ``UploadFile`` backed by a ``SpooledTemporaryFile``.
    - File bytes are only queued by the callbacks; ``parse()`` writes them,
      because that is where the async ``UploadFile`` methods can be awaited.
    """

    spool_max_size = 1024 * 1024  # 1MB
    """The maximum size of the spooled temporary file used to store file data."""
    max_part_size = 1024 * 1024  # 1MB
    """The maximum size of a part in the multipart request."""

    def __init__(
        self,
        headers: Headers,
        stream: AsyncGenerator[bytes, None],
        *,
        max_files: int | float = 1000,
        max_fields: int | float = 1000,
        max_part_size: int = 1024 * 1024,  # 1MB
    ) -> None:
        assert multipart is not None, "The `python-multipart` library must be installed to use form parsing."
        self.headers = headers
        self.stream = stream
        self.max_files = max_files
        self.max_fields = max_fields
        # (field name, value) pairs in body order; values are str or UploadFile.
        self.items: list[tuple[str, str | UploadFile]] = []
        # Running counts checked against max_files / max_fields.
        self._current_files = 0
        self._current_fields = 0
        # Header name/value fragments, joined until on_header_end.
        self._current_partial_header_name: bytes = b""
        self._current_partial_header_value: bytes = b""
        self._current_part = MultipartPart()
        # Set from the Content-Type charset in parse(); used to decode names and values.
        self._charset = ""
        # File work queued by the synchronous callbacks and awaited in parse().
        self._file_parts_to_write: list[tuple[MultipartPart, bytes]] = []
        self._file_parts_to_finish: list[MultipartPart] = []
        # Every spooled file created so far; closed if parse() fails.
        self._files_to_close_on_error: list[SpooledTemporaryFile[bytes]] = []
        self.max_part_size = max_part_size

    def on_part_begin(self) -> None:
        """Start accumulating a new part."""
        self._current_part = MultipartPart()

    def on_part_data(self, data: bytes, start: int, end: int) -> None:
        """Buffer ``data[start:end]`` for a plain field, or queue it for a file part.

        Raises:
            MultiPartException: a non-file part would grow beyond ``max_part_size``.
        """
        message_bytes = data[start:end]
        if self._current_part.file is None:
            if len(self._current_part.data) + len(message_bytes) > self.max_part_size:
                raise MultiPartException(f"Part exceeded maximum size of {int(self.max_part_size / 1024)}KB.")
            self._current_part.data.extend(message_bytes)
        else:
            self._file_parts_to_write.append((self._current_part, message_bytes))

    def on_part_end(self) -> None:
        """Append the finished part to ``self.items``."""
        if self._current_part.file is None:
            self.items.append(
                (
                    self._current_part.field_name,
                    _user_safe_decode(self._current_part.data, self._charset),
                )
            )
        else:
            self._file_parts_to_finish.append(self._current_part)
            # The file can be added to the items right now even though it's not
            # finished yet, because it will be finished in the `parse()` method, before
            # self.items is used in the return value.
            self.items.append((self._current_part.field_name, self._current_part.file))

    def on_header_field(self, data: bytes, start: int, end: int) -> None:
        """Accumulate a fragment of the current header's name."""
        self._current_partial_header_name += data[start:end]

    def on_header_value(self, data: bytes, start: int, end: int) -> None:
        """Accumulate a fragment of the current header's value."""
        self._current_partial_header_value += data[start:end]

    def on_header_end(self) -> None:
        """Store the completed header on the current part and reset the fragments."""
        field = self._current_partial_header_name.lower()
        if field == b"content-disposition":
            self._current_part.content_disposition = self._current_partial_header_value
        self._current_part.item_headers.append((field, self._current_partial_header_value))
        self._current_partial_header_name = b""
        self._current_partial_header_value = b""

    def on_headers_finished(self) -> None:
        """Classify the part as a file or a plain field from its Content-Disposition.

        Raises:
            MultiPartException: the ``name`` parameter is missing, or the
                ``max_files`` / ``max_fields`` limit is exceeded.
        """
        disposition, options = parse_options_header(self._current_part.content_disposition)
        try:
            self._current_part.field_name = _user_safe_decode(options[b"name"], self._charset)
        except KeyError:
            raise MultiPartException('The Content-Disposition header field "name" must be provided.')
        if b"filename" in options:
            self._current_files += 1
            if self._current_files > self.max_files:
                raise MultiPartException(f"Too many files. Maximum number of files is {self.max_files}.")
            filename = _user_safe_decode(options[b"filename"], self._charset)
            tempfile = SpooledTemporaryFile(max_size=self.spool_max_size)
            self._files_to_close_on_error.append(tempfile)
            self._current_part.file = UploadFile(
                file=tempfile,  # type: ignore[arg-type]
                size=0,
                filename=filename,
                headers=Headers(raw=self._current_part.item_headers),
            )
        else:
            self._current_fields += 1
            if self._current_fields > self.max_fields:
                raise MultiPartException(f"Too many fields. Maximum number of fields is {self.max_fields}.")
            self._current_part.file = None

    def on_end(self) -> None:
        """End-of-body callback; no work is needed here."""
        pass

    async def parse(self) -> FormData:
        """Consume ``self.stream`` and return the parsed fields and files.

        Raises:
            MultiPartException: the Content-Type has no boundary, or a callback
                rejected the body (missing part name, size or count limit).

        If anything raises while the body is being consumed, every spooled file
        created so far is closed before the exception propagates. This covers
        errors from the stream itself, errors from the parser, and task
        cancellation as well as MultiPartException.
        """
        # Parse the Content-Type header to get the multipart boundary and charset.
        _, params = parse_options_header(self.headers["Content-Type"])
        charset = params.get(b"charset", "utf-8")
        if isinstance(charset, bytes):
            charset = charset.decode("latin-1")
        self._charset = charset
        try:
            boundary = params[b"boundary"]
        except KeyError:
            raise MultiPartException("Missing boundary in multipart.")

        # Callbacks dictionary.
        callbacks: MultipartCallbacks = {
            "on_part_begin": self.on_part_begin,
            "on_part_data": self.on_part_data,
            "on_part_end": self.on_part_end,
            "on_header_field": self.on_header_field,
            "on_header_value": self.on_header_value,
            "on_header_end": self.on_header_end,
            "on_headers_finished": self.on_headers_finished,
            "on_end": self.on_end,
        }

        # Create the parser.
        parser = multipart.MultipartParser(boundary, callbacks)
        try:
            # Feed the parser with data from the request.
            async for chunk in self.stream:
                parser.write(chunk)
                # Write file data here, with await on the UploadFile methods
                # (which call the file methods *in a threadpool*). Calling them
                # directly from the synchronous callbacks above would block the
                # event loop.
                for part, data in self._file_parts_to_write:
                    assert part.file  # for type checkers
                    await part.file.write(data)
                for part in self._file_parts_to_finish:
                    assert part.file  # for type checkers
                    await part.file.seek(0)
                self._file_parts_to_write.clear()
                self._file_parts_to_finish.clear()
            parser.finalize()
        except BaseException:
            # Close all the files on *any* failure; otherwise they would leak,
            # since no FormData referencing them is returned.
            for file in self._files_to_close_on_error:
                file.close()
            raise

        return FormData(self.items)