Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/starlette/formparsers.py: 28%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

167 statements  

1from __future__ import annotations 

2 

3import typing 

4from dataclasses import dataclass, field 

5from enum import Enum 

6from tempfile import SpooledTemporaryFile 

7from urllib.parse import unquote_plus 

8 

9from starlette.datastructures import FormData, Headers, UploadFile 

10 

11try: 

12 import multipart 

13 from multipart.multipart import parse_options_header 

14except ModuleNotFoundError: # pragma: nocover 

15 parse_options_header = None 

16 multipart = None 

17 

18 

class FormMessage(Enum):
    """Event kinds queued by :class:`FormParser`'s parser callbacks."""

    FIELD_START = 1  # start of a new field
    FIELD_NAME = 2  # a slice of the current field's (still-encoded) name
    FIELD_DATA = 3  # a slice of the current field's (still-encoded) value
    FIELD_END = 4  # the current field is complete
    END = 5  # the parser reached the end of the body

25 

26 

@dataclass
class MultipartPart:
    """Mutable state collected for the multipart part currently being parsed."""

    # Raw Content-Disposition header value; None until that header is seen.
    content_disposition: bytes | None = None
    # Decoded ``name`` parameter taken from the Content-Disposition header.
    field_name: str = ""
    # Body bytes buffered for a non-file part.
    data: bytes = b""
    # Upload target for a file part; None for a plain field.
    file: UploadFile | None = None
    # Every (lower-cased name, value) header pair seen for this part.
    item_headers: list[tuple[bytes, bytes]] = field(default_factory=list)

34 

35 

36def _user_safe_decode(src: bytes, codec: str) -> str: 

37 try: 

38 return src.decode(codec) 

39 except (UnicodeDecodeError, LookupError): 

40 return src.decode("latin-1") 

41 

42 

class MultiPartException(Exception):
    """Raised when a form body is malformed or exceeds a configured limit.

    ``message`` holds the human-readable reason.
    """

    def __init__(self, message: str) -> None:
        # Run Exception's own initialisation so ``args``/``str()`` carry the
        # message explicitly instead of relying on ``BaseException.__new__``.
        super().__init__(message)
        self.message = message

46 

47 

class FormParser:
    """Parse an ``application/x-www-form-urlencoded`` body into :class:`FormData`.

    python-multipart's ``QuerystringParser`` invokes the ``on_*`` callbacks
    below. They only queue ``(FormMessage, bytes)`` events in
    ``self.messages``; ``parse()`` drains that queue after every chunk and
    assembles the completed fields.
    """

    def __init__(self, headers: Headers, stream: typing.AsyncGenerator[bytes, None]) -> None:
        assert multipart is not None, "The `python-multipart` library must be installed to use form parsing."
        self.headers = headers
        self.stream = stream
        # Parser events waiting to be consumed by ``parse()``.
        self.messages: list[tuple[FormMessage, bytes]] = []

    def on_field_start(self) -> None:
        message = (FormMessage.FIELD_START, b"")
        self.messages.append(message)

    def on_field_name(self, data: bytes, start: int, end: int) -> None:
        message = (FormMessage.FIELD_NAME, data[start:end])
        self.messages.append(message)

    def on_field_data(self, data: bytes, start: int, end: int) -> None:
        message = (FormMessage.FIELD_DATA, data[start:end])
        self.messages.append(message)

    def on_field_end(self) -> None:
        message = (FormMessage.FIELD_END, b"")
        self.messages.append(message)

    def on_end(self) -> None:
        message = (FormMessage.END, b"")
        self.messages.append(message)

    async def parse(self) -> FormData:
        """Consume ``self.stream`` and return the decoded form fields.

        Each field name and value is decoded as latin-1 and then passed
        through :func:`urllib.parse.unquote_plus`.

        The parser is finalized whenever an empty chunk arrives. If the
        stream ends without a trailing empty chunk, it is finalized once more
        after the stream is exhausted so the last field is not dropped.
        """
        # Callbacks dictionary.
        callbacks = {
            "on_field_start": self.on_field_start,
            "on_field_name": self.on_field_name,
            "on_field_data": self.on_field_data,
            "on_field_end": self.on_field_end,
            "on_end": self.on_end,
        }

        # Create the parser.
        parser = multipart.QuerystringParser(callbacks)
        field_name = b""
        field_value = b""

        items: list[tuple[str, str | UploadFile]] = []

        def drain_messages() -> None:
            # Fold the queued parser events into completed (name, value) items.
            nonlocal field_name, field_value
            messages = list(self.messages)
            self.messages.clear()
            for message_type, message_bytes in messages:
                if message_type == FormMessage.FIELD_START:
                    field_name = b""
                    field_value = b""
                elif message_type == FormMessage.FIELD_NAME:
                    field_name += message_bytes
                elif message_type == FormMessage.FIELD_DATA:
                    field_value += message_bytes
                elif message_type == FormMessage.FIELD_END:
                    name = unquote_plus(field_name.decode("latin-1"))
                    value = unquote_plus(field_value.decode("latin-1"))
                    items.append((name, value))

        # True when the most recent parser call was ``finalize()``.
        finalized = False

        # Feed the parser with data from the request.
        async for chunk in self.stream:
            if chunk:
                parser.write(chunk)
                finalized = False
            else:
                # An empty chunk is treated as the end-of-body signal.
                parser.finalize()
                finalized = True
            drain_messages()

        # The stream stopped without an empty terminator chunk: finalize now so
        # a pending trailing field still produces FIELD_END.
        if not finalized:
            parser.finalize()
            drain_messages()

        return FormData(items)

114 

115 

class MultiPartParser:
    """Parse a ``multipart/form-data`` body into :class:`FormData`.

    python-multipart's ``MultipartParser`` invokes the synchronous ``on_*``
    callbacks below. Non-file parts are buffered in memory, capped by
    ``max_part_size``. File parts are queued and written to a
    ``SpooledTemporaryFile`` from ``parse()``, where the ``UploadFile``
    methods can be awaited.
    """

    # ``max_size`` given to each SpooledTemporaryFile: an upload is kept in
    # memory up to this many bytes, then rolled over to disk.
    max_file_size = 1024 * 1024

    def __init__(
        self,
        headers: Headers,
        stream: typing.AsyncGenerator[bytes, None],
        *,
        max_files: int | float = 1000,
        max_fields: int | float = 1000,
        max_part_size: int = 1024 * 1024,
    ) -> None:
        """Create a parser for one request body.

        Args:
            headers: Request headers; ``Content-Type`` must carry the boundary.
            stream: Async iterator over the raw body chunks.
            max_files: Maximum number of file parts before MultiPartException.
            max_fields: Maximum number of non-file parts before MultiPartException.
            max_part_size: Maximum size in bytes of a single non-file part,
                which is held entirely in memory.
        """
        assert multipart is not None, "The `python-multipart` library must be installed to use form parsing."
        self.headers = headers
        self.stream = stream
        self.max_files = max_files
        self.max_fields = max_fields
        self.max_part_size = max_part_size
        self.items: list[tuple[str, str | UploadFile]] = []
        self._current_files = 0
        self._current_fields = 0
        # Header names/values can arrive split across callbacks; they are
        # accumulated here until ``on_header_end``.
        self._current_partial_header_name: bytes = b""
        self._current_partial_header_value: bytes = b""
        self._current_part = MultipartPart()
        self._charset = ""
        # Work queued by the synchronous callbacks and awaited in ``parse()``.
        self._file_parts_to_write: list[tuple[MultipartPart, bytes]] = []
        self._file_parts_to_finish: list[MultipartPart] = []
        # Every temp file created, so all can be closed if parsing fails.
        self._files_to_close_on_error: list[SpooledTemporaryFile[bytes]] = []

    def on_part_begin(self) -> None:
        self._current_part = MultipartPart()

    def on_part_data(self, data: bytes, start: int, end: int) -> None:
        message_bytes = data[start:end]
        if self._current_part.file is None:
            # Non-file parts live entirely in memory, so cap their size; without
            # this a client can exhaust memory with one huge text field.
            if len(self._current_part.data) + len(message_bytes) > self.max_part_size:
                raise MultiPartException(f"Part exceeded maximum size of {self.max_part_size // 1024}KB.")
            self._current_part.data += message_bytes
        else:
            # Writing must be awaited, which this sync callback cannot do.
            self._file_parts_to_write.append((self._current_part, message_bytes))

    def on_part_end(self) -> None:
        if self._current_part.file is None:
            self.items.append(
                (
                    self._current_part.field_name,
                    _user_safe_decode(self._current_part.data, self._charset),
                )
            )
        else:
            self._file_parts_to_finish.append(self._current_part)
            # The file can be added to the items right now even though it's not
            # finished yet, because it will be finished in the `parse()` method, before
            # self.items is used in the return value.
            self.items.append((self._current_part.field_name, self._current_part.file))

    def on_header_field(self, data: bytes, start: int, end: int) -> None:
        self._current_partial_header_name += data[start:end]

    def on_header_value(self, data: bytes, start: int, end: int) -> None:
        self._current_partial_header_value += data[start:end]

    def on_header_end(self) -> None:
        field = self._current_partial_header_name.lower()
        if field == b"content-disposition":
            self._current_part.content_disposition = self._current_partial_header_value
        self._current_part.item_headers.append((field, self._current_partial_header_value))
        self._current_partial_header_name = b""
        self._current_partial_header_value = b""

    def on_headers_finished(self) -> None:
        disposition, options = parse_options_header(self._current_part.content_disposition)
        try:
            self._current_part.field_name = _user_safe_decode(options[b"name"], self._charset)
        except KeyError:
            raise MultiPartException('The Content-Disposition header field "name" must be provided.') from None
        if b"filename" in options:
            self._current_files += 1
            if self._current_files > self.max_files:
                raise MultiPartException(f"Too many files. Maximum number of files is {self.max_files}.")
            filename = _user_safe_decode(options[b"filename"], self._charset)
            tempfile = SpooledTemporaryFile(max_size=self.max_file_size)
            self._files_to_close_on_error.append(tempfile)
            self._current_part.file = UploadFile(
                file=tempfile,  # type: ignore[arg-type]
                size=0,
                filename=filename,
                headers=Headers(raw=self._current_part.item_headers),
            )
        else:
            self._current_fields += 1
            if self._current_fields > self.max_fields:
                raise MultiPartException(f"Too many fields. Maximum number of fields is {self.max_fields}.")
            self._current_part.file = None

    def on_end(self) -> None:
        pass

    async def _flush_file_parts(self) -> None:
        """Perform the file work queued by the callbacks, then clear the queues.

        Writes buffered data to each part's ``UploadFile`` and rewinds every
        finished file to position 0. The ``UploadFile`` methods are awaited
        here rather than called from the synchronous callbacks. They call the
        corresponding file methods in a threadpool, so calling them directly
        would block the event loop.
        """
        for part, data in self._file_parts_to_write:
            assert part.file  # for type checkers
            await part.file.write(data)
        for part in self._file_parts_to_finish:
            assert part.file  # for type checkers
            await part.file.seek(0)
        self._file_parts_to_write.clear()
        self._file_parts_to_finish.clear()

    async def parse(self) -> FormData:
        """Consume ``self.stream`` and return the parsed fields and files.

        Raises:
            MultiPartException: missing boundary, missing part name, too many
                files/fields, or a non-file part larger than ``max_part_size``.

        If parsing fails for any reason, every temp file created so far is
        closed before the exception propagates.
        """
        # Parse the Content-Type header to get the multipart boundary.
        _, params = parse_options_header(self.headers["Content-Type"])
        charset = params.get(b"charset", "utf-8")
        if isinstance(charset, bytes):
            charset = charset.decode("latin-1")
        self._charset = charset
        try:
            boundary = params[b"boundary"]
        except KeyError:
            raise MultiPartException("Missing boundary in multipart.") from None

        # Callbacks dictionary.
        callbacks = {
            "on_part_begin": self.on_part_begin,
            "on_part_data": self.on_part_data,
            "on_part_end": self.on_part_end,
            "on_header_field": self.on_header_field,
            "on_header_value": self.on_header_value,
            "on_header_end": self.on_header_end,
            "on_headers_finished": self.on_headers_finished,
            "on_end": self.on_end,
        }

        # Create the parser.
        parser = multipart.MultipartParser(boundary, callbacks)
        try:
            # Feed the parser with data from the request.
            async for chunk in self.stream:
                parser.write(chunk)
                await self._flush_file_parts()
            # Finalize inside the try (so failures still close the files) and
            # flush anything the final callbacks may have queued.
            parser.finalize()
            await self._flush_file_parts()
        except BaseException:
            # Cleanup-and-reraise for any failure (our own limits, parser
            # errors, stream errors, cancellation) so temp files never leak.
            for file in self._files_to_close_on_error:
                file.close()
            raise

        return FormData(self.items)