Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/starlette/formparsers.py: 28%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

167 statements  

1from __future__ import annotations 

2 

3import typing 

4from dataclasses import dataclass, field 

5from enum import Enum 

6from tempfile import SpooledTemporaryFile 

7from urllib.parse import unquote_plus 

8 

9from starlette.datastructures import FormData, Headers, UploadFile 

10 

11try: 

12 import multipart 

13 from multipart.multipart import parse_options_header 

14except ModuleNotFoundError: # pragma: nocover 

15 parse_options_header = None 

16 multipart = None 

17 

18 

class FormMessage(Enum):
    """Event kinds that ``FormParser`` callbacks queue while parsing.

    Each queued message is a ``(FormMessage, bytes)`` pair. Only
    ``FIELD_NAME`` and ``FIELD_DATA`` carry a non-empty payload; the other
    kinds are queued with ``b""``.
    """

    # A new ``name=value`` pair is starting.
    FIELD_START = 1
    # A slice of the (still percent-encoded) field name.
    FIELD_NAME = 2
    # A slice of the (still percent-encoded) field value.
    FIELD_DATA = 3
    # The current pair is complete and can be decoded.
    FIELD_END = 4
    # The parser reported the end of its input.
    END = 5

25 

26 

@dataclass
class MultipartPart:
    """Mutable state for the multipart part currently being parsed.

    A fresh instance is created at the start of every part and filled in as
    the header and body callbacks arrive.
    """

    # Raw ``Content-Disposition`` header value; None until that header is seen.
    content_disposition: bytes | None = None
    # Decoded ``name`` parameter taken from the Content-Disposition header.
    field_name: str = ""
    # Body bytes accumulated for non-file parts.
    data: bytes = b""
    # Upload target for parts that declared a ``filename``; None for plain fields.
    file: UploadFile | None = None
    # Every header of the part as (lowercased name, raw value) pairs.
    item_headers: list[tuple[bytes, bytes]] = field(default_factory=list)

34 

35 

36def _user_safe_decode(src: bytes, codec: str) -> str: 

37 try: 

38 return src.decode(codec) 

39 except (UnicodeDecodeError, LookupError): 

40 return src.decode("latin-1") 

41 

42 

class MultiPartException(Exception):
    """Error raised when a ``multipart/form-data`` body cannot be parsed.

    Raised for a missing boundary, a part without a ``name``, or when the
    configured file/field limits are exceeded.

    Attributes:
        message: Human-readable description of the problem.
    """

    def __init__(self, message: str) -> None:
        # Forward to Exception so str(exc), repr(exc) and exc.args carry the
        # message, and so the exception can be pickled and unpickled.
        super().__init__(message)
        self.message = message

46 

47 

class FormParser:
    """Parser for ``application/x-www-form-urlencoded`` request bodies.

    Body chunks from ``stream`` are fed to ``multipart.QuerystringParser``.
    Its callbacks only queue ``(FormMessage, bytes)`` events on
    ``self.messages``; ``parse()`` drains that queue after each chunk and turns
    every completed field into a ``(name, value)`` string pair.
    """

    def __init__(
        self, headers: Headers, stream: typing.AsyncGenerator[bytes, None]
    ) -> None:
        assert (
            multipart is not None
        ), "The `python-multipart` library must be installed to use form parsing."
        self.headers = headers
        self.stream = stream
        # Events queued by the parser callbacks and drained by ``parse()``.
        self.messages: list[tuple[FormMessage, bytes]] = []

    def on_field_start(self) -> None:
        """Queue the start of a new field."""
        message = (FormMessage.FIELD_START, b"")
        self.messages.append(message)

    def on_field_name(self, data: bytes, start: int, end: int) -> None:
        """Queue the slice ``data[start:end]`` of the current field's name."""
        message = (FormMessage.FIELD_NAME, data[start:end])
        self.messages.append(message)

    def on_field_data(self, data: bytes, start: int, end: int) -> None:
        """Queue the slice ``data[start:end]`` of the current field's value."""
        message = (FormMessage.FIELD_DATA, data[start:end])
        self.messages.append(message)

    def on_field_end(self) -> None:
        """Queue the end of the current field."""
        message = (FormMessage.FIELD_END, b"")
        self.messages.append(message)

    def on_end(self) -> None:
        """Queue the end of the input."""
        message = (FormMessage.END, b"")
        self.messages.append(message)

    def _drain_messages(
        self,
        items: list[tuple[str, str | UploadFile]],
        field_name: bytes,
        field_value: bytes,
    ) -> tuple[bytes, bytes]:
        """Consume queued events, appending every finished field to *items*.

        A field may be split across chunks, so the raw name/value bytes of the
        field in progress are passed in and returned for the next call.
        """
        messages = list(self.messages)
        self.messages.clear()
        for message_type, message_bytes in messages:
            if message_type == FormMessage.FIELD_START:
                field_name = b""
                field_value = b""
            elif message_type == FormMessage.FIELD_NAME:
                field_name += message_bytes
            elif message_type == FormMessage.FIELD_DATA:
                field_value += message_bytes
            elif message_type == FormMessage.FIELD_END:
                # NOTE: bytes are decoded as latin-1 *before* percent-decoding,
                # so only percent-escaped sequences are decoded as UTF-8.
                name = unquote_plus(field_name.decode("latin-1"))
                value = unquote_plus(field_value.decode("latin-1"))
                items.append((name, value))
        return field_name, field_value

    async def parse(self) -> FormData:
        """Consume ``self.stream`` and return the decoded form fields.

        An empty chunk marks the end of the body and finalizes the parser.
        If the stream is exhausted without such a chunk, the parser is
        finalized after the loop instead. python-multipart emits the final
        field's end event from ``finalize()``, so skipping it would silently
        drop the last field. ``finalize()`` runs at most once, so repeated
        empty chunks cannot emit the last field twice.
        """
        # Callbacks dictionary.
        callbacks = {
            "on_field_start": self.on_field_start,
            "on_field_name": self.on_field_name,
            "on_field_data": self.on_field_data,
            "on_field_end": self.on_field_end,
            "on_end": self.on_end,
        }

        # Create the parser.
        parser = multipart.QuerystringParser(callbacks)
        field_name = b""
        field_value = b""
        finalized = False

        items: list[tuple[str, str | UploadFile]] = []

        # Feed the parser with data from the request.
        async for chunk in self.stream:
            if chunk:
                parser.write(chunk)
            elif not finalized:
                parser.finalize()
                finalized = True
            field_name, field_value = self._drain_messages(
                items, field_name, field_value
            )

        if not finalized:
            # The stream ended without an explicit empty chunk.
            parser.finalize()
            field_name, field_value = self._drain_messages(
                items, field_name, field_value
            )

        return FormData(items)

118 

119 

class MultiPartParser:
    """Streaming parser for ``multipart/form-data`` request bodies.

    ``multipart.MultipartParser`` invokes the synchronous ``on_*`` callbacks
    while ``parse()`` feeds it body chunks. Plain fields are accumulated in
    memory and decoded with the request charset. File parts are backed by
    ``UploadFile`` objects. Their data is buffered by the callbacks and
    written by ``parse()`` between chunks using ``await``.
    """

    # ``max_size`` given to SpooledTemporaryFile: the number of bytes an upload
    # keeps in memory before rolling over to disk. It is not an upload limit.
    max_file_size = 1024 * 1024

    def __init__(
        self,
        headers: Headers,
        stream: typing.AsyncGenerator[bytes, None],
        *,
        max_files: int | float = 1000,
        max_fields: int | float = 1000,
    ) -> None:
        """Store the request inputs and reset all per-parse state.

        Args:
            headers: Request headers; ``Content-Type`` must carry a boundary.
            stream: Async iterator of raw body chunks.
            max_files: Maximum number of file parts before parsing aborts.
            max_fields: Maximum number of non-file parts before parsing aborts.
        """
        assert (
            multipart is not None
        ), "The `python-multipart` library must be installed to use form parsing."
        self.headers = headers
        self.stream = stream
        self.max_files = max_files
        self.max_fields = max_fields
        self.items: list[tuple[str, str | UploadFile]] = []
        self._current_files = 0
        self._current_fields = 0
        # Header name/value may arrive in several callback slices.
        self._current_partial_header_name: bytes = b""
        self._current_partial_header_value: bytes = b""
        self._current_part = MultipartPart()
        self._charset = ""
        # File I/O queued by the sync callbacks, performed later by parse().
        self._file_parts_to_write: list[tuple[MultipartPart, bytes]] = []
        self._file_parts_to_finish: list[MultipartPart] = []
        # Every temp file created, so all can be closed if parsing fails.
        self._files_to_close_on_error: list[SpooledTemporaryFile[bytes]] = []

    def on_part_begin(self) -> None:
        """Start tracking a new part."""
        self._current_part = MultipartPart()

    def on_part_data(self, data: bytes, start: int, end: int) -> None:
        """Record body bytes of the current part.

        Field data is accumulated in memory. File data is queued for
        ``parse()`` to write, because these callbacks cannot ``await``.
        """
        message_bytes = data[start:end]
        if self._current_part.file is None:
            self._current_part.data += message_bytes
        else:
            self._file_parts_to_write.append((self._current_part, message_bytes))

    def on_part_end(self) -> None:
        """Publish the finished part to ``self.items``."""
        if self._current_part.file is None:
            self.items.append(
                (
                    self._current_part.field_name,
                    _user_safe_decode(self._current_part.data, self._charset),
                )
            )
        else:
            self._file_parts_to_finish.append(self._current_part)
            # The file can be added to the items right now even though it's not
            # finished yet, because it will be finished in the `parse()` method,
            # before self.items is used in the return value.
            self.items.append((self._current_part.field_name, self._current_part.file))

    def on_header_field(self, data: bytes, start: int, end: int) -> None:
        """Accumulate a slice of the current header's name."""
        self._current_partial_header_name += data[start:end]

    def on_header_value(self, data: bytes, start: int, end: int) -> None:
        """Accumulate a slice of the current header's value."""
        self._current_partial_header_value += data[start:end]

    def on_header_end(self) -> None:
        """Store the completed header on the current part and reset buffers."""
        field = self._current_partial_header_name.lower()
        if field == b"content-disposition":
            self._current_part.content_disposition = self._current_partial_header_value
        self._current_part.item_headers.append(
            (field, self._current_partial_header_value)
        )
        self._current_partial_header_name = b""
        self._current_partial_header_value = b""

    def on_headers_finished(self) -> None:
        """Classify the current part as a file or a plain field.

        Parts whose Content-Disposition has a ``filename`` get an
        ``UploadFile`` backed by a SpooledTemporaryFile; others are fields.

        Raises:
            MultiPartException: if ``name`` is missing or the file/field
                limit is exceeded.
        """
        disposition, options = parse_options_header(
            self._current_part.content_disposition
        )
        try:
            raw_name = options[b"name"]
        except KeyError:
            raise MultiPartException(
                'The Content-Disposition header field "name" must be ' "provided."
            ) from None
        self._current_part.field_name = _user_safe_decode(raw_name, self._charset)
        if b"filename" in options:
            self._current_files += 1
            if self._current_files > self.max_files:
                raise MultiPartException(
                    f"Too many files. Maximum number of files is {self.max_files}."
                )
            filename = _user_safe_decode(options[b"filename"], self._charset)
            tempfile = SpooledTemporaryFile(max_size=self.max_file_size)
            self._files_to_close_on_error.append(tempfile)
            self._current_part.file = UploadFile(
                file=tempfile,  # type: ignore[arg-type]
                size=0,
                filename=filename,
                headers=Headers(raw=self._current_part.item_headers),
            )
        else:
            self._current_fields += 1
            if self._current_fields > self.max_fields:
                raise MultiPartException(
                    f"Too many fields. Maximum number of fields is {self.max_fields}."
                )
            self._current_part.file = None

    def on_end(self) -> None:
        """No action needed; ``parse()`` completes once the stream is exhausted."""
        pass

    async def _flush_file_parts(self) -> None:
        """Write queued file data, then rewind files whose part has ended.

        This runs here rather than in the callbacks because the UploadFile
        methods call the underlying file methods in a threadpool and must be
        awaited. Calling the blocking file API directly from the sync
        callbacks would block the event loop.
        """
        for part, data in self._file_parts_to_write:
            assert part.file  # for type checkers
            await part.file.write(data)
        for part in self._file_parts_to_finish:
            assert part.file  # for type checkers
            await part.file.seek(0)
        self._file_parts_to_write.clear()
        self._file_parts_to_finish.clear()

    async def parse(self) -> FormData:
        """Consume ``self.stream`` and return all parsed fields and files.

        Raises:
            MultiPartException: if the Content-Type lacks a boundary, or a
                part is invalid or exceeds the configured limits.

        If anything fails after parsing starts, every temporary file created
        so far is closed before the exception propagates. This covers
        multipart errors, malformed input, stream errors and cancellation.
        """
        # Parse the Content-Type header to get the multipart boundary.
        _, params = parse_options_header(self.headers["Content-Type"])
        charset = params.get(b"charset", "utf-8")
        if isinstance(charset, bytes):
            charset = charset.decode("latin-1")
        self._charset = charset
        try:
            boundary = params[b"boundary"]
        except KeyError:
            raise MultiPartException("Missing boundary in multipart.") from None

        # Callbacks dictionary.
        callbacks = {
            "on_part_begin": self.on_part_begin,
            "on_part_data": self.on_part_data,
            "on_part_end": self.on_part_end,
            "on_header_field": self.on_header_field,
            "on_header_value": self.on_header_value,
            "on_header_end": self.on_header_end,
            "on_headers_finished": self.on_headers_finished,
            "on_end": self.on_end,
        }

        # Create the parser.
        parser = multipart.MultipartParser(boundary, callbacks)
        try:
            # Feed the parser with data from the request.
            async for chunk in self.stream:
                parser.write(chunk)
                await self._flush_file_parts()
            parser.finalize()
            # Handle anything emitted while finalizing.
            await self._flush_file_parts()
        except BaseException:
            # Cleanup-and-reraise: the caller never receives these files, so
            # close them (deleting any rolled-over temp file) before propagating.
            for file in self._files_to_close_on_error:
                file.close()
            raise

        return FormData(self.items)