Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/starlette/formparsers.py: 27%

166 statements  

« prev     ^ index     » next       coverage.py v7.2.2, created at 2023-03-26 06:12 +0000

1import typing 

2from dataclasses import dataclass, field 

3from enum import Enum 

4from tempfile import SpooledTemporaryFile 

5from urllib.parse import unquote_plus 

6 

7from starlette.datastructures import FormData, Headers, UploadFile 

8 

9try: 

10 import multipart 

11 from multipart.multipart import parse_options_header 

12except ImportError: # pragma: nocover 

13 parse_options_header = None 

14 multipart = None 

15 

16 

# Kinds of events queued by FormParser's QuerystringParser callbacks.
# FormParser.parse() rebuilds each field from FIELD_START / FIELD_NAME /
# FIELD_DATA / FIELD_END; END is queued but not acted upon there.
# Built with the functional API: members get the values 1..5 in list order.
FormMessage = Enum(
    "FormMessage",
    ["FIELD_START", "FIELD_NAME", "FIELD_DATA", "FIELD_END", "END"],
    module=__name__,
)

23 

24 

@dataclass
class MultipartPart:
    """Mutable state collected for one part of a multipart/form-data body.

    MultiPartParser fills an instance in from its callbacks: header pairs
    first, then either raw bytes in ``data`` (plain fields) or an
    ``UploadFile`` in ``file`` (parts whose disposition names a filename).
    """

    # Raw Content-Disposition header value, if that header was seen.
    content_disposition: typing.Union[bytes, None] = field(default=None)
    # Field name decoded from the Content-Disposition "name" option.
    field_name: str = field(default="")
    # Accumulated body bytes for a non-file part.
    data: bytes = field(default=b"")
    # Upload target for file parts; None for plain fields.
    file: typing.Union[UploadFile, None] = field(default=None)
    # Every (lower-cased header name, raw value) pair seen for this part.
    item_headers: typing.List[typing.Tuple[bytes, bytes]] = field(
        default_factory=list
    )

32 

33 

34def _user_safe_decode(src: bytes, codec: str) -> str: 

35 try: 

36 return src.decode(codec) 

37 except (UnicodeDecodeError, LookupError): 

38 return src.decode("latin-1") 

39 

40 

class MultiPartException(Exception):
    """Raised when a multipart/form-data body cannot be accepted.

    Used for a missing boundary, a part without a ``name`` option, and
    bodies exceeding the configured file/field limits.
    """

    def __init__(self, message: str) -> None:
        # Forward the message to Exception so that str(exc), exc.args, and
        # tracebacks show it. Previously str(exc) was an empty string.
        super().__init__(message)
        self.message = message

44 

45 

class FormParser:
    """Parse an ``application/x-www-form-urlencoded`` request body.

    The body is fed chunk by chunk to ``multipart.QuerystringParser``. Its
    callbacks (the ``on_*`` methods) only queue ``(FormMessage, bytes)``
    events on ``self.messages``; :meth:`parse` drains that queue after every
    chunk and assembles the events into ``(name, value)`` pairs.
    """

    def __init__(
        self, headers: Headers, stream: typing.AsyncGenerator[bytes, None]
    ) -> None:
        assert (
            multipart is not None
        ), "The `python-multipart` library must be installed to use form parsing."
        self.headers = headers
        self.stream = stream
        # Events queued by the parser callbacks and drained by parse().
        self.messages: typing.List[typing.Tuple[FormMessage, bytes]] = []

    def on_field_start(self) -> None:
        message = (FormMessage.FIELD_START, b"")
        self.messages.append(message)

    def on_field_name(self, data: bytes, start: int, end: int) -> None:
        # parse() concatenates successive FIELD_NAME pieces for one field.
        message = (FormMessage.FIELD_NAME, data[start:end])
        self.messages.append(message)

    def on_field_data(self, data: bytes, start: int, end: int) -> None:
        # parse() concatenates successive FIELD_DATA pieces for one field.
        message = (FormMessage.FIELD_DATA, data[start:end])
        self.messages.append(message)

    def on_field_end(self) -> None:
        message = (FormMessage.FIELD_END, b"")
        self.messages.append(message)

    def on_end(self) -> None:
        message = (FormMessage.END, b"")
        self.messages.append(message)

    async def _chunks_with_terminator(self) -> typing.AsyncGenerator[bytes, None]:
        """Yield the chunks of ``self.stream``, adding a final ``b""`` if absent.

        parse() finalizes the parser only when it sees an empty chunk. The
        last field has no trailing separator, so its end event is presumably
        emitted only on finalization (per python-multipart's QuerystringParser).
        Without this helper, a stream that does not end with ``b""`` would
        silently lose its final field.
        """
        ended_with_empty_chunk = False
        async for chunk in self.stream:
            ended_with_empty_chunk = not chunk
            yield chunk
        if not ended_with_empty_chunk:
            yield b""

    async def parse(self) -> FormData:
        """Consume the body stream and return the decoded form fields."""
        # Callbacks dictionary.
        callbacks = {
            "on_field_start": self.on_field_start,
            "on_field_name": self.on_field_name,
            "on_field_data": self.on_field_data,
            "on_field_end": self.on_field_end,
            "on_end": self.on_end,
        }

        # Create the parser.
        parser = multipart.QuerystringParser(callbacks)
        field_name = b""
        field_value = b""

        items: typing.List[typing.Tuple[str, typing.Union[str, UploadFile]]] = []

        # Feed the parser with data from the request.
        async for chunk in self._chunks_with_terminator():
            if chunk:
                parser.write(chunk)
            else:
                # An empty chunk marks the end of the body.
                parser.finalize()
            messages = list(self.messages)
            self.messages.clear()
            for message_type, message_bytes in messages:
                if message_type == FormMessage.FIELD_START:
                    field_name = b""
                    field_value = b""
                elif message_type == FormMessage.FIELD_NAME:
                    field_name += message_bytes
                elif message_type == FormMessage.FIELD_DATA:
                    field_value += message_bytes
                elif message_type == FormMessage.FIELD_END:
                    # latin-1 maps every byte to a code point, so this decode
                    # cannot fail; unquote_plus then resolves "+" and %-escapes.
                    name = unquote_plus(field_name.decode("latin-1"))
                    value = unquote_plus(field_value.decode("latin-1"))
                    items.append((name, value))

        return FormData(items)

116 

117 

class MultiPartParser:
    """Parse a ``multipart/form-data`` request body into :class:`FormData`.

    Plain parts become ``str`` values. Parts whose Content-Disposition carries
    a ``filename`` become :class:`UploadFile` objects backed by a
    ``SpooledTemporaryFile``. ``max_files`` and ``max_fields`` cap how many
    of each one body may contain; exceeding either raises
    :class:`MultiPartException`.

    The ``on_*`` methods are synchronous parser callbacks. They only record
    state, and any file I/O they imply is queued and awaited in :meth:`parse`.
    """

    # max_size for each upload's SpooledTemporaryFile: data is kept in memory
    # up to this many bytes, then rolled over to a real temporary file.
    max_file_size = 1024 * 1024

    def __init__(
        self,
        headers: Headers,
        stream: typing.AsyncGenerator[bytes, None],
        *,
        max_files: typing.Union[int, float] = 1000,
        max_fields: typing.Union[int, float] = 1000,
    ) -> None:
        assert (
            multipart is not None
        ), "The `python-multipart` library must be installed to use form parsing."
        self.headers = headers
        self.stream = stream
        self.max_files = max_files
        self.max_fields = max_fields
        # Parsed (field name, value) pairs, in body order.
        self.items: typing.List[typing.Tuple[str, typing.Union[str, UploadFile]]] = []
        # Running counts checked against max_files / max_fields.
        self._current_files = 0
        self._current_fields = 0
        # Header names/values may arrive in pieces; accumulate until header end.
        self._current_partial_header_name: bytes = b""
        self._current_partial_header_value: bytes = b""
        self._current_part = MultipartPart()
        self._charset = ""
        # File work queued by the callbacks and performed with await in parse().
        self._file_parts_to_write: typing.List[typing.Tuple[MultipartPart, bytes]] = []
        self._file_parts_to_finish: typing.List[MultipartPart] = []
        # Every temporary file created, so all can be closed if parsing fails.
        self._files_to_close_on_error: typing.List[SpooledTemporaryFile] = []

    def on_part_begin(self) -> None:
        # Start a fresh part; its headers and data arrive via later callbacks.
        self._current_part = MultipartPart()

    def on_part_data(self, data: bytes, start: int, end: int) -> None:
        message_bytes = data[start:end]
        if self._current_part.file is None:
            # Plain field: buffer the value in memory until on_part_end.
            self._current_part.data += message_bytes
        else:
            # File part: defer the write so parse() can await it.
            self._file_parts_to_write.append((self._current_part, message_bytes))

    def on_part_end(self) -> None:
        if self._current_part.file is None:
            self.items.append(
                (
                    self._current_part.field_name,
                    _user_safe_decode(self._current_part.data, self._charset),
                )
            )
        else:
            self._file_parts_to_finish.append(self._current_part)
            # The file can be added to the items right now even though it's not
            # finished yet, because it will be finished in the `parse()` method,
            # before self.items is used in the return value.
            self.items.append((self._current_part.field_name, self._current_part.file))

    def on_header_field(self, data: bytes, start: int, end: int) -> None:
        self._current_partial_header_name += data[start:end]

    def on_header_value(self, data: bytes, start: int, end: int) -> None:
        self._current_partial_header_value += data[start:end]

    def on_header_end(self) -> None:
        # Header names are compared and stored lower-cased.
        header_name = self._current_partial_header_name.lower()
        if header_name == b"content-disposition":
            self._current_part.content_disposition = self._current_partial_header_value
        self._current_part.item_headers.append(
            (header_name, self._current_partial_header_value)
        )
        self._current_partial_header_name = b""
        self._current_partial_header_value = b""

    def on_headers_finished(self) -> None:
        """Classify the current part as a plain field or a file upload.

        Raises:
            MultiPartException: if Content-Disposition has no ``name`` option,
                or this part would exceed ``max_files`` / ``max_fields``.
        """
        _, options = parse_options_header(self._current_part.content_disposition)
        try:
            self._current_part.field_name = _user_safe_decode(
                options[b"name"], self._charset
            )
        except KeyError:
            raise MultiPartException(
                'The Content-Disposition header field "name" must be ' "provided."
            ) from None
        if b"filename" in options:
            self._current_files += 1
            if self._current_files > self.max_files:
                raise MultiPartException(
                    f"Too many files. Maximum number of files is {self.max_files}."
                )
            filename = _user_safe_decode(options[b"filename"], self._charset)
            tempfile = SpooledTemporaryFile(max_size=self.max_file_size)
            # Register before anything else can fail, so it is closed on error.
            self._files_to_close_on_error.append(tempfile)
            self._current_part.file = UploadFile(
                file=tempfile,  # type: ignore[arg-type]
                size=0,
                filename=filename,
                headers=Headers(raw=self._current_part.item_headers),
            )
        else:
            self._current_fields += 1
            if self._current_fields > self.max_fields:
                raise MultiPartException(
                    f"Too many fields. Maximum number of fields is {self.max_fields}."
                )
            self._current_part.file = None

    def on_end(self) -> None:
        pass

    async def _flush_file_parts(self) -> None:
        """Perform the file writes and rewinds queued by the callbacks.

        This must use await with the UploadFile methods, which call the
        corresponding file methods in a threadpool. Calling them directly in
        the synchronous callbacks would block the event loop in the main
        thread.
        """
        for part, data in self._file_parts_to_write:
            assert part.file  # for type checkers
            await part.file.write(data)
        for part in self._file_parts_to_finish:
            assert part.file  # for type checkers
            await part.file.seek(0)
        self._file_parts_to_write.clear()
        self._file_parts_to_finish.clear()

    async def parse(self) -> FormData:
        """Consume the body stream and return the parsed fields and files.

        Raises:
            MultiPartException: on a missing boundary, a part without a
                ``name``, or too many files/fields.

        If parsing fails for any reason, every temporary file created so far
        is closed before the error propagates.
        """
        # Parse the Content-Type header to get the multipart boundary.
        _, params = parse_options_header(self.headers["Content-Type"])
        charset = params.get(b"charset", "utf-8")
        if isinstance(charset, bytes):
            charset = charset.decode("latin-1")
        self._charset = charset
        try:
            boundary = params[b"boundary"]
        except KeyError:
            raise MultiPartException("Missing boundary in multipart.") from None

        # Callbacks dictionary.
        callbacks = {
            "on_part_begin": self.on_part_begin,
            "on_part_data": self.on_part_data,
            "on_part_end": self.on_part_end,
            "on_header_field": self.on_header_field,
            "on_header_value": self.on_header_value,
            "on_header_end": self.on_header_end,
            "on_headers_finished": self.on_headers_finished,
            "on_end": self.on_end,
        }

        # Create the parser.
        parser = multipart.MultipartParser(boundary, callbacks)
        try:
            # Feed the parser with data from the request, flushing queued file
            # work after each chunk.
            async for chunk in self.stream:
                parser.write(chunk)
                await self._flush_file_parts()
            parser.finalize()
        except BaseException:
            # Close every temporary file on *any* failure, not only
            # MultiPartException. This also covers errors raised by the
            # underlying parser or the stream, and task cancellation, so no
            # spooled file is leaked. Then propagate the original error.
            for file in self._files_to_close_on_error:
                file.close()
            raise

        return FormData(self.items)