1from __future__ import annotations
2
3import logging
4import os
5import shutil
6import sys
7import tempfile
8from enum import IntEnum
9from io import BufferedRandom, BytesIO
10from numbers import Number
11from typing import TYPE_CHECKING, cast
12
13from .decoders import Base64Decoder, QuotedPrintableDecoder
14from .exceptions import FileError, FormParserError, MultipartParseError, QuerystringParseError
15
if TYPE_CHECKING:
    # Everything in this block is visible to static type checkers only;
    # ``TYPE_CHECKING`` is False at runtime, so none of it is executed.
    from collections.abc import Callable
    from typing import Any, Literal, Protocol, TypeAlias, TypedDict

    # Structural type: any object with a ``read(n) -> bytes`` method.
    class SupportsRead(Protocol):
        def read(self, __n: int) -> bytes: ...

    # Callback table for the querystring parser. ``total=False`` makes every
    # key optional. Callbacks typed ``[bytes, int, int]`` receive
    # ``(data, start, end)``; the others take no arguments.
    class QuerystringCallbacks(TypedDict, total=False):
        on_field_start: Callable[[], None]
        on_field_name: Callable[[bytes, int, int], None]
        on_field_data: Callable[[bytes, int, int], None]
        on_field_end: Callable[[], None]
        on_end: Callable[[], None]

    # Callback table for the octet-stream parser (all keys optional).
    class OctetStreamCallbacks(TypedDict, total=False):
        on_start: Callable[[], None]
        on_data: Callable[[bytes, int, int], None]
        on_end: Callable[[], None]

    # Callback table for the multipart parser (all keys optional).
    class MultipartCallbacks(TypedDict, total=False):
        on_part_begin: Callable[[], None]
        on_part_data: Callable[[bytes, int, int], None]
        on_part_end: Callable[[], None]
        on_header_begin: Callable[[], None]
        on_header_field: Callable[[bytes, int, int], None]
        on_header_value: Callable[[bytes, int, int], None]
        on_header_end: Callable[[], None]
        on_headers_finished: Callable[[], None]
        on_end: Callable[[], None]

    # Configuration keys read by ``File`` (all keys optional).
    class FileConfig(TypedDict, total=False):
        UPLOAD_DIR: str | bytes | None
        UPLOAD_DELETE_TMP: bool
        UPLOAD_KEEP_FILENAME: bool
        UPLOAD_KEEP_EXTENSIONS: bool
        MAX_MEMORY_FILE_SIZE: int

    # ``FileConfig`` extended with additional keys.
    # NOTE(review): presumably consumed by the form parser, which is outside
    # this excerpt - confirm against its implementation.
    class FormParserConfig(FileConfig):
        UPLOAD_ERROR_ON_BAD_CTE: bool
        MAX_BODY_SIZE: float
        MAX_HEADER_COUNT: int
        MAX_HEADER_SIZE: int

    # Callback names as passed to ``BaseParser.callback`` and
    # ``BaseParser.set_callback``, i.e. the callback-table keys without their
    # ``on_`` prefix.
    CallbackName: TypeAlias = Literal[
        "start",
        "data",
        "end",
        "field_start",
        "field_name",
        "field_data",
        "field_end",
        "part_begin",
        "part_data",
        "part_end",
        "header_begin",
        "header_field",
        "header_value",
        "header_end",
        "headers_finished",
    ]
76
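# Unique sentinel meaning "no cached value yet". None cannot play this role
# because None is itself a valid field value (see Field.set_none).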
78_missing = object()
79
80
class QuerystringState(IntEnum):
    """Querystring parser states.

    These are used to keep track of the state of the parser, and are used to determine
    what to do when new data is encountered.
    """

    # Between fields: "&" separators are skipped here, and the first other
    # byte starts a new field.
    BEFORE_FIELD = 0
    # Inside a field name: scanning for the "=" that ends the name, or for a
    # "&" that ends a field which has no value.
    FIELD_NAME = 1
    # Inside a field value: scanning for the "&" that ends the field.
    FIELD_DATA = 2
91
92
class MultipartState(IntEnum):
    """Multipart parser states.

    These are used to keep track of the state of the parser, and are used to determine
    what to do when new data is encountered.

    The members are numbered consecutively from 0 in declaration order.
    """

    # NOTE(review): the parser that consumes these states is outside this
    # excerpt. The names suggest the order in which a multipart body is walked
    # (opening boundary, header field/value pairs, part data, closing
    # boundary) - confirm against the parser before relying on that.
    START = 0
    START_BOUNDARY = 1
    HEADER_FIELD_START = 2
    HEADER_FIELD = 3
    HEADER_VALUE_START = 4
    HEADER_VALUE = 5
    HEADER_VALUE_ALMOST_DONE = 6
    HEADERS_ALMOST_DONE = 7
    PART_DATA_START = 8
    PART_DATA = 9
    PART_DATA_END = 10
    END_BOUNDARY = 11
    END = 12
113
114
# Bit flags for the multipart parser.
FLAG_PART_BOUNDARY = 1 << 0
FLAG_LAST_BOUNDARY = 1 << 1

# Integer values of the bytes the parsers compare against. Indexing (or
# iterating over) a bytes object yields ints, so the constants are ints too.
CR = ord(b"\r")
LF = ord(b"\n")
COLON = ord(b":")
SPACE = ord(b" ")
HYPHEN = ord(b"-")
AMPERSAND = ord(b"&")
LOWER_A = ord(b"a")
LOWER_Z = ord(b"z")
NULL = ord(b"\x00")

# ASCII characters that may appear in an HTTP token. Per RFC 7230 section
# 3.2.6 these are the alphanumerics plus: !#$%&'*+-.^_`|~
TOKEN_CHARS = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789!#$%&'*+-.^_`|~"
# The same characters as a set of integer byte values, for O(1) membership tests.
TOKEN_CHARS_SET = frozenset(TOKEN_CHARS)

DEFAULT_MAX_HEADER_COUNT = 8
"""Default maximum number of headers allowed per multipart part."""

DEFAULT_MAX_HEADER_SIZE = 4096 + 128
"""Default maximum size of a single multipart header line, including syntax overhead."""

MAX_BOUNDARY_LENGTH = 256
"""Maximum allowed length of a multipart boundary.

[RFC 2046 §5.1.1](https://datatracker.ietf.org/doc/html/rfc2046#section-5.1.1)
recommends boundaries be at most 70 bytes. 256 bytes is generous headroom over
every HTTP client.
"""
157
158
159def _parseparam(s: str) -> list[str]:
160 # Vendored from the standard library's
161 # [`email.message._parseparam`](https://github.com/python/cpython/blob/v3.14.2/Lib/email/message.py#L73-L96)
162 # to split a header into its `;`-separated parts without treating a `;` inside a double-quoted string as a
163 # separator - and without the RFC 2231 decoding that `email.message.Message.get_params` would apply on top.
164 s = ";" + s
165 plist: list[str] = []
166 start = 0
167 while s.find(";", start) == start:
168 start += 1
169 end = s.find(";", start)
170 ind, diff = start, 0
171 while end > 0:
172 diff += s.count('"', ind, end) - s.count('\\"', ind, end)
173 if diff % 2 == 0:
174 break
175 end, ind = ind, s.find(";", end + 1)
176 if end < 0:
177 end = len(s)
178 i = s.find("=", start, end)
179 if i == -1:
180 f = s[start:end]
181 else:
182 f = s[start:i].rstrip().lower() + "=" + s[i + 1 : end].lstrip()
183 plist.append(f.strip())
184 start = end
185 return plist
186
187
def parse_options_header(value: str | bytes | None) -> tuple[bytes, dict[bytes, bytes]]:
    """Parse a header such as Content-Type into ``(value, {parameters})``.

    The leading value (e.g. the media type) is returned lower-cased and
    stripped, whether or not parameters follow it. Parameter names are
    lower-cased; surrounding double quotes are removed from parameter values
    and the ``\\\\`` / ``\\"`` escapes inside them are undone.

    Args:
        value: The raw header value. ``bytes`` are decoded as latin-1; ``None``
            and the empty string are accepted.

    Returns:
        A tuple of the leading value and a dict of its parameters, all encoded
        as latin-1 bytes. ``(b"", {})`` when `value` is empty.
    """
    if not value:
        return (b"", {})

    # If we are passed bytes, we assume that it conforms to WSGI, encoding in latin-1.
    if isinstance(value, bytes):  # pragma: no cover
        value = value.decode("latin-1")

    # For types
    assert isinstance(value, str), "Value should be a string by now"

    # Fast path: no parameters at all.
    if ";" not in value:
        return (value.lower().strip().encode("latin-1"), {})

    ctype, *segments = _parseparam(value)
    options: dict[bytes, bytes] = {}
    for segment in segments:
        key, _, val = segment.partition("=")
        # A trailing or doubled ";" yields an empty segment; it names no
        # parameter, so it must not show up as a b"" option.
        if not key:
            continue
        # [RFC 7578 §4.2](https://datatracker.ietf.org/doc/html/rfc7578#section-4.2)
        # forbids the RFC 5987/2231 extended syntax (`key*=`, `key*0`, ...) in
        # multipart/form-data, so we ignore those parameters and keep the plain
        # `key` authoritative.
        if "*" in key:
            continue
        if len(val) >= 2 and val[0] == '"' and val[-1] == '"':
            val = val[1:-1].replace("\\\\", "\\").replace('\\"', '"')
        # Work around an IE6 bug where the full file path is sent instead of
        # just the filename.
        if key == "filename" and (val[1:3] == ":\\" or val[:2] == "\\\\"):
            val = val.split("\\")[-1]
        options[key.encode("latin-1")] = val.encode("latin-1")
    # Media types are case-insensitive; lower-case here as well so that the
    # result does not depend on whether the header carried parameters.
    return ctype.lower().encode("latin-1"), options
222
223
class Field:
    """A Field object represents a (parsed) form field. It represents a single
    field with a corresponding name and value.

    The name that a :class:`Field` will be instantiated with is the same name
    that would be found in the following HTML::

        <input name="name_goes_here" type="text"/>

    This class defines two methods, :meth:`on_data` and :meth:`on_end`, that
    will be called when data is written to the Field, and when the Field is
    finalized, respectively.

    Instances compare equal when both their name and their value are equal.
    Because ``__eq__`` is defined without ``__hash__``, instances are not
    hashable.

    Args:
        name: The name of the form field.
        content_type: The value of the Content-Type header for this field.
    """

    def __init__(self, name: bytes | None, *, content_type: str | None = None) -> None:
        self._name = name
        # The chunks written so far, in order.
        self._value: list[bytes] = []
        self._content_type = content_type

        # We cache the joined version of _value for speed. `_missing` means
        # "not joined yet"; None is a real value (see set_none).
        self._cache = _missing

    @classmethod
    def from_value(cls, name: bytes, value: bytes | None) -> Field:
        """Create an instance of a :class:`Field`, and set the corresponding
        value - either None or an actual value. This method will also
        finalize the Field itself.

        Args:
            name: the name of the form field.
            value: the value of the form field - either a bytestring or None.

        Returns:
            A new instance of a [`Field`][python_multipart.Field].
        """

        f = cls(name)
        if value is None:
            f.set_none()
        else:
            f.write(value)
        f.finalize()
        return f

    def _materialize(self) -> None:
        """Join the written chunks into the cached value, unless one is already cached."""
        if self._cache is _missing:
            self._cache = b"".join(self._value)

    def write(self, data: bytes) -> int:
        """Write some data into the form field.

        Args:
            data: The data to write to the field.

        Returns:
            The number of bytes written.
        """
        return self.on_data(data)

    def on_data(self, data: bytes) -> int:
        """This method is a callback that will be called whenever data is
        written to the Field.

        Args:
            data: The data to write to the field.

        Returns:
            The number of bytes written.
        """
        self._value.append(data)
        # The cached join (if any) is stale now.
        self._cache = _missing
        return len(data)

    def on_end(self) -> None:
        """This method is called whenever the Field is finalized."""
        self._materialize()

    def finalize(self) -> None:
        """Finalize the form field."""
        self.on_end()

    def close(self) -> None:
        """Close the Field object, releasing the list of written chunks.

        The joined value stays available through :attr:`value`. Calling this
        method more than once is harmless.
        """
        self._materialize()
        # Drop the individual chunks, but keep the buffer consistent with the
        # cache so that a repeated close() or a late write() neither raises nor
        # loses the data written so far. The single element is the cached
        # object itself, so no extra copy is held.
        self._value = [self._cache] if isinstance(self._cache, bytes) else []

    def set_none(self) -> None:
        """Some fields in a querystring can possibly have a value of None - for
        example, the string "foo&bar=&baz=asdf" will have a field with the
        name "foo" and value None, one with name "bar" and value "", and one
        with name "baz" and value "asdf". Since the write() interface doesn't
        support writing None, this function will set the field value to None.
        """
        self._cache = None

    @property
    def field_name(self) -> bytes | None:
        """This property returns the name of the field."""
        return self._name

    @property
    def value(self) -> bytes | None:
        """This property returns the value of the form field."""
        self._materialize()

        assert isinstance(self._cache, bytes) or self._cache is None
        return self._cache

    @property
    def content_type(self) -> str | None:
        """This property returns the content_type value of the field."""
        return self._content_type

    def __eq__(self, other: object) -> bool:
        if isinstance(other, Field):
            return self.field_name == other.field_name and self.value == other.value
        else:
            return NotImplemented

    def __repr__(self) -> str:
        value = self.value
        if value is not None and len(value) > 97:
            # Show only the first 97 bytes and put three dots before the
            # closing quote. repr() of bytes closes with either ' or ",
            # depending on the content, so reuse whichever it picked.
            truncated = repr(value[:97])
            v = f"{truncated[:-1]}...{truncated[-1]}"
        else:
            v = repr(value)

        return f"{self.__class__.__name__}(field_name={self.field_name!r}, value={v})"
357
358
class File:
    """This class represents an uploaded file. It handles writing file data to
    either an in-memory file or a temporary file on-disk, if the optional
    threshold is passed.

    There are some options that can be passed to the File to change behavior
    of the class. Valid options are as follows:

    | Name                  | Type  | Default | Description |
    |-----------------------|-------|---------|-------------|
    | UPLOAD_DIR            | `str` | None    | The directory to store uploaded files in. If this is None, a temporary file will be created in the system's standard location. |
    | UPLOAD_DELETE_TMP     | `bool`| True    | Delete automatically created TMP file |
    | UPLOAD_KEEP_FILENAME  | `bool`| False   | Whether or not to keep the filename of the uploaded file. If True, then the filename will be converted to a safe representation (e.g. by removing any invalid path segments), and then saved with the same name). Otherwise, a temporary name will be used. |
    | UPLOAD_KEEP_EXTENSIONS| `bool`| False   | Whether or not to keep the uploaded file's extension. If False, the file will be saved with the default temporary extension (usually ".tmp"). Otherwise, the file's extension will be maintained. Note that this will properly combine with the UPLOAD_KEEP_FILENAME setting. |
    | MAX_MEMORY_FILE_SIZE  | `int` | 1 MiB   | The maximum number of bytes of a File to keep in memory. By default, the contents of a File are kept into memory until a certain limit is reached, after which the contents of the File are written to a temporary file. This behavior can be disabled by setting this value to an appropriately large value (or, for example, infinity, such as `float('inf')`). |

    Note that this class applies no default of its own for
    `MAX_MEMORY_FILE_SIZE`: when the key is absent from `config`, the data is
    only moved to disk by an explicit call to `flush_to_disk`.

    Args:
        file_name: The name of the file that this [`File`][python_multipart.File] represents.
        field_name: The name of the form field that this file was uploaded with. This can be None, if, for example,
            the file was uploaded with Content-Type application/octet-stream.
        config: The configuration for this File. See above for valid configuration keys and their corresponding values.
        content_type: The value of the Content-Type header.
    """  # noqa: E501

    def __init__(
        self,
        file_name: bytes | None,
        field_name: bytes | None = None,
        config: FileConfig = {},
        *,
        content_type: str | None = None,
    ) -> None:
        # Save configuration, set other variables default. The config mapping
        # is only ever read by this class, never modified.
        self.logger = logging.getLogger(__name__)
        self._config = config
        self._in_memory = True
        self._bytes_written = 0
        self._fileobj: BytesIO | BufferedRandom = BytesIO()

        # Save the provided field/file name and content type.
        self._field_name = field_name
        self._file_name = file_name
        self._content_type = content_type

        # Our actual file name is None by default, since, depending on our
        # config, we may not actually use the provided name.
        self._actual_file_name: bytes | None = None

        # Split the extension from the filename. Both parts default to empty so
        # that they always exist, even for uploads without a file name.
        self._file_base: bytes = b""
        self._ext: bytes = b""
        if file_name is not None:
            # Extract just the basename to avoid directory traversal
            basename = os.path.basename(file_name)
            self._file_base, self._ext = os.path.splitext(basename)

    @property
    def field_name(self) -> bytes | None:
        """The form field associated with this file. May be None if there isn't
        one, for example when we have an application/octet-stream upload.
        """
        return self._field_name

    @property
    def file_name(self) -> bytes | None:
        """The file name given in the upload request."""
        return self._file_name

    @property
    def actual_file_name(self) -> bytes | None:
        """The file name that this file is saved as. Will be None if it's not
        currently saved on disk.
        """
        return self._actual_file_name

    @property
    def file_object(self) -> BytesIO | BufferedRandom:
        """The file object that we're currently writing to. Note that this
        will either be an instance of a :class:`io.BytesIO`, or a regular file
        object.
        """
        return self._fileobj

    @property
    def size(self) -> int:
        """The total size of this file, counted as the number of bytes that
        currently have been written to the file.
        """
        return self._bytes_written

    @property
    def in_memory(self) -> bool:
        """A boolean representing whether or not this file object is currently
        stored in-memory or on-disk.
        """
        return self._in_memory

    @property
    def content_type(self) -> str | None:
        """The Content-Type value for this part, if it was set."""
        return self._content_type

    def flush_to_disk(self) -> None:
        """If the file is already on-disk, do nothing. Otherwise, copy from
        the in-memory buffer to a disk file, and then reassign our internal
        file object to this new disk file.

        Note that if you attempt to flush a file that is already on-disk, a
        warning will be logged to this module's logger.

        Raises:
            FileError: If the on-disk file cannot be created.
        """
        if not self._in_memory:
            self.logger.warning("Trying to flush to disk when we're not in memory")
            return

        # Go back to the start of our file.
        self._fileobj.seek(0)

        # Open a new file.
        new_file = self._get_disk_file()

        # Copy the file objects.
        shutil.copyfileobj(self._fileobj, new_file)

        # Seek to the new position in our new file.
        new_file.seek(self._bytes_written)

        # Reassign the fileobject.
        old_fileobj = self._fileobj
        self._fileobj = new_file

        # We're no longer in memory.
        self._in_memory = False

        # Close the old file object.
        old_fileobj.close()

    def _get_disk_file(self) -> BufferedRandom:
        """Open the on-disk file that replaces the in-memory buffer.

        With both ``UPLOAD_DIR`` and ``UPLOAD_KEEP_FILENAME`` set - and a usable
        client file name - the file is created (or truncated) under that name
        inside the upload directory. In every other case a named temporary
        file is created instead. The resulting path is recorded as
        `actual_file_name`.

        Returns:
            The opened binary read/write file object.

        Raises:
            FileError: If the file cannot be opened or created.
        """
        self.logger.info("Opening a file on disk")

        file_dir = self._config.get("UPLOAD_DIR")
        keep_filename = self._config.get("UPLOAD_KEEP_FILENAME", False)
        keep_extensions = self._config.get("UPLOAD_KEEP_EXTENSIONS", False)
        delete_tmp = self._config.get("UPLOAD_DELETE_TMP", True)
        tmp_file: None | BufferedRandom = None

        # If we have a directory and are to keep the filename... Without a
        # usable base name (no file name given, or an empty one) there is
        # nothing to keep, so we fall back to a temporary name.
        if file_dir is not None and keep_filename and self._file_base:
            self.logger.info("Saving with filename in: %r", file_dir)

            # Build our filename.
            fname = self._file_base + self._ext if keep_extensions else self._file_base

            # File names are bytes while UPLOAD_DIR may be str or bytes;
            # os.path.join refuses to mix the two, so normalise to bytes.
            path = os.path.join(os.fsencode(file_dir), fname)
            try:
                self.logger.info("Opening file: %r", path)
                tmp_file = open(path, "w+b")
            except OSError as exc:
                self.logger.exception("Error opening temporary file")
                raise FileError("Error opening temporary file: %r" % path) from exc
        else:
            # tempfile is given text arguments. os.fsdecode applies the
            # filesystem encoding and its error handler, so an extension that
            # is not valid in that encoding (it comes from the client) does
            # not raise here.
            suffix = os.fsdecode(self._ext) if keep_extensions else None
            tmp_dir = None if file_dir is None else os.fsdecode(file_dir)

            # Create a temporary (named) file with the appropriate settings.
            self.logger.info(
                "Creating a temporary file with options: %r", {"suffix": suffix, "delete": delete_tmp, "dir": tmp_dir}
            )
            try:
                tmp_file = cast(
                    BufferedRandom, tempfile.NamedTemporaryFile(suffix=suffix, delete=delete_tmp, dir=tmp_dir)
                )
            except OSError as exc:
                self.logger.exception("Error creating named temporary file")
                raise FileError("Error creating named temporary file") from exc

        assert tmp_file is not None
        # Record the path as bytes; os.fsencode passes bytes through unchanged.
        self._actual_file_name = os.fsencode(tmp_file.name)
        return tmp_file

    def write(self, data: bytes) -> int:
        """Write some data to the File.

        Args:
            data: The data to write to the file.

        Returns:
            The number of bytes written.
        """
        return self.on_data(data)

    def on_data(self, data: bytes) -> int:
        """This method is a callback that will be called whenever data is
        written to the File.

        Args:
            data: The data to write to the file.

        Returns:
            The number of bytes written.
        """
        bwritten = self._fileobj.write(data)

        # If the bytes written isn't the same as the length, just return.
        # Note that such a short write is not added to the size counter.
        if bwritten != len(data):
            self.logger.warning("bwritten != len(data) (%d != %d)", bwritten, len(data))
            return bwritten

        # Keep track of how many bytes we've written.
        self._bytes_written += bwritten

        # If we're in-memory and are over our limit, we create a file.
        max_memory_file_size = self._config.get("MAX_MEMORY_FILE_SIZE")
        if self._in_memory and max_memory_file_size is not None and (self._bytes_written > max_memory_file_size):
            self.logger.info("Flushing to disk")
            self.flush_to_disk()

        # Return the number of bytes written.
        return bwritten

    def on_end(self) -> None:
        """This method is called whenever the File is finalized."""
        # Flush the underlying file object
        self._fileobj.flush()

    def finalize(self) -> None:
        """Finalize the form file. This will not close the underlying file,
        but simply signal that we are finished writing to the File.
        """
        self.on_end()

    def close(self) -> None:
        """Close the File object. This will actually close the underlying
        file object (whether it's a :class:`io.BytesIO` or an actual file
        object).
        """
        self._fileobj.close()

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(file_name={self.file_name!r}, field_name={self.field_name!r})"
611
612
class BaseParser:
    """Common base of all parsers: stores the callback table and knows how to
    invoke and replace its entries.

    There are two flavours of callback. A *notification callback* announces
    that something happened - for instance that a new part of a multipart
    message began - and is called without arguments. A *data callback*
    delivers a piece of data - for instance part of the body of a multipart
    chunk - and is called with three arguments::

        data_callback(data, start, end)

    `data` is a bytestring, and `start` / `end` are integer indexes into it,
    so `data[start:end]` is the portion the callback should look at. The
    slice is deliberately not taken on the callback's behalf, because copying
    the data would severely hurt performance.
    """

    def __init__(self) -> None:
        self.logger = logging.getLogger(__name__)
        self.callbacks: QuerystringCallbacks | OctetStreamCallbacks | MultipartCallbacks = {}

    def callback(
        self, name: CallbackName, data: bytes | None = None, start: int | None = None, end: int | None = None
    ) -> None:
        """Invoke the callback registered under `name`; do nothing when no
        such callback is registered.

        Args:
            name: The callback's name, without the ``on_`` prefix.
            data: The buffer to hand to a data callback. When None, the callback is treated as a notification
                callback and called without arguments.
            start: Index into `data` where the relevant portion begins.
            end: Index into `data` where the relevant portion ends.
        """
        handler = self.callbacks.get("on_" + name)
        if handler is None:
            return
        handler = cast("Callable[..., Any]", handler)

        if data is None:
            # Notification callback: nothing to pass along.
            handler()
            return

        # An empty range carries no data, so the callback is not bothered.
        if start is not None and start == end:
            return
        handler(data, start, end)

    def set_callback(self, name: CallbackName, new_func: Callable[..., Any] | None) -> None:
        """Register, replace or remove the callback stored under `name`.

        :param name: The name of the callback, without the ``on_`` prefix.

        :param new_func: The function to register. Passing None removes the
                         callback instead; removing a callback that is not
                         registered is not an error.
        """
        key = "on_" + name
        if new_func is not None:
            self.callbacks[key] = new_func  # type: ignore[literal-required]
            return
        self.callbacks.pop(key, None)  # type: ignore[misc]

    def close(self) -> None:
        """Do nothing; present so that every parser offers ``close()``."""

    def finalize(self) -> None:
        """Do nothing; present so that every parser offers ``finalize()``."""

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}()"
687
688
class OctetStreamParser(BaseParser):
    """This parser parses an octet-stream request body and calls callbacks when
    incoming data is received. Callbacks are as follows:

    | Callback Name  | Parameters      | Description                                         |
    |----------------|-----------------|-----------------------------------------------------|
    | on_start       | None            | Called when the first data is parsed.               |
    | on_data        | data, start, end| Called for each data chunk that is parsed.          |
    | on_end         | None            | Called when the parser is finished parsing all data.|

    Args:
        callbacks: A dictionary of callbacks. See the documentation for [`BaseParser`][python_multipart.BaseParser].
            When omitted, the parser starts with an empty table of its own.
        max_size: The maximum size of body to parse. Defaults to infinity - i.e. unbounded.

    Raises:
        ValueError: If `max_size` is not a number, or is smaller than 1.
    """

    def __init__(self, callbacks: OctetStreamCallbacks | None = None, max_size: float = float("inf")) -> None:
        super().__init__()
        # set_callback() mutates this table in place, so the default must be a
        # fresh dict per instance rather than one object shared by all parsers.
        # A table passed by the caller is used as-is, not copied.
        self.callbacks = callbacks if callbacks is not None else {}
        self._started = False

        if not isinstance(max_size, Number) or max_size < 1:
            raise ValueError(f"max_size must be a positive number, not {max_size!r}")
        self.max_size: int | float = max_size
        self._current_size = 0

    def write(self, data: bytes) -> int:
        """Write some data to the parser, which will perform size verification,
        and then pass the data to the underlying callback.

        The first call emits ``on_start``. Data beyond `max_size` is not
        passed on: the chunk is truncated and a warning is logged.

        Args:
            data: The data to write to the parser.

        Returns:
            The number of bytes written, which is less than ``len(data)`` when
            the chunk was truncated.
        """
        if not self._started:
            self.callback("start")
            self._started = True

        # Truncate data length.
        data_len = len(data)
        if (self._current_size + data_len) > self.max_size:
            # We truncate the length of data that we are to process.
            new_size = int(self.max_size - self._current_size)
            self.logger.warning(
                "Current size is %d (max %d), so truncating data length from %d to %d",
                self._current_size,
                self.max_size,
                data_len,
                new_size,
            )
            data_len = new_size

        # Increment size, then callback, in case there's an exception.
        self._current_size += data_len
        self.callback("data", data, 0, data_len)
        return data_len

    def finalize(self) -> None:
        """Finalize this parser, which signals that we are finished parsing,
        and sends the on_end callback.
        """
        self.callback("end")

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}()"
755
756
757class QuerystringParser(BaseParser):
758 """This is a streaming querystring parser. It will consume data, and call
759 the callbacks given when it has data.
760
761 | Callback Name | Parameters | Description |
762 |----------------|-----------------|-----------------------------------------------------|
763 | on_field_start | None | Called when a new field is encountered. |
764 | on_field_name | data, start, end| Called when a portion of a field's name is encountered. |
765 | on_field_data | data, start, end| Called when a portion of a field's data is encountered. |
766 | on_field_end | None | Called when the end of a field is encountered. |
767 | on_end | None | Called when the parser is finished parsing all data.|
768
769 Args:
770 callbacks: A dictionary of callbacks. See the documentation for [`BaseParser`][python_multipart.BaseParser].
771 strict_parsing: Whether or not to parse the body strictly. Defaults to False. If this is set to True, then the
772 behavior of the parser changes as the following: if a field has a value with an equal sign
773 (e.g. "foo=bar", or "foo="), it is always included. If a field has no equals sign (e.g. "...&name&..."),
774 it will be treated as an error if 'strict_parsing' is True, otherwise included. If an error is encountered,
775 then a [`QuerystringParseError`][python_multipart.exceptions.QuerystringParseError] will be raised.
776 max_size: The maximum size of body to parse. Defaults to infinity - i.e. unbounded.
777 """ # noqa: E501
778
779 state: QuerystringState
780
781 def __init__(
782 self, callbacks: QuerystringCallbacks = {}, strict_parsing: bool = False, max_size: float = float("inf")
783 ) -> None:
784 super().__init__()
785 self.state = QuerystringState.BEFORE_FIELD
786 self._found_sep = False
787
788 self.callbacks = callbacks
789
790 # Max-size stuff
791 if not isinstance(max_size, Number) or max_size < 1:
792 raise ValueError("max_size must be a positive number, not %r" % max_size)
793 self.max_size: int | float = max_size
794 self._current_size = 0
795
796 # Should parsing be strict?
797 self.strict_parsing = strict_parsing
798
799 def write(self, data: bytes) -> int:
800 """Write some data to the parser, which will perform size verification,
801 parse into either a field name or value, and then pass the
802 corresponding data to the underlying callback. If an error is
803 encountered while parsing, a QuerystringParseError will be raised. The
804 "offset" attribute of the raised exception will be set to the offset in
805 the input data chunk (NOT the overall stream) that caused the error.
806
807 Args:
808 data: The data to write to the parser.
809
810 Returns:
811 The number of bytes written.
812 """
813 # Handle sizing.
814 data_len = len(data)
815 if (self._current_size + data_len) > self.max_size:
816 # We truncate the length of data that we are to process.
817 new_size = int(self.max_size - self._current_size)
818 self.logger.warning(
819 "Current size is %d (max %d), so truncating data length from %d to %d",
820 self._current_size,
821 self.max_size,
822 data_len,
823 new_size,
824 )
825 data_len = new_size
826
827 l = 0
828 try:
829 l = self._internal_write(data, data_len)
830 finally:
831 self._current_size += l
832
833 return l
834
    def _internal_write(self, data: bytes, length: int) -> int:
        """Run the state machine over ``data[:length]`` and emit callbacks.

        The parser state is copied into locals, advanced while scanning, and
        written back to ``self.state`` / ``self._found_sep`` only after the
        whole range has been processed. If an exception is raised part-way,
        those two attributes therefore keep the values they had before the
        call, although callbacks for the already-scanned bytes have been made.

        Args:
            data: The chunk to parse.
            length: How many bytes of `data`, counted from its start, to parse.

        Returns:
            `length`, i.e. the number of bytes processed.

        Raises:
            QuerystringParseError: In strict mode, on a duplicate separator or
                on a field without an equals sign that is followed by a
                separator. The error's ``offset`` is an index into `data`.
        """
        state = self.state
        strict_parsing = self.strict_parsing
        found_sep = self._found_sep

        i = 0
        while i < length:
            ch = data[i]

            # Depending on our state...
            if state == QuerystringState.BEFORE_FIELD:
                # If the 'found_sep' flag is set, we've already encountered
                # and skipped a single separator. If so, we check our strict
                # parsing flag and decide what to do. Otherwise, we haven't
                # yet reached a separator, and thus, if we do, we need to skip
                # it as it will be the boundary between fields that's supposed
                # to be there.
                if ch == AMPERSAND:
                    if found_sep:
                        # If we're parsing strictly, we disallow blank chunks.
                        if strict_parsing:
                            raise QuerystringParseError("Skipping duplicate ampersand at %d" % i, offset=i)
                        else:
                            self.logger.debug("Skipping duplicate ampersand at %d", i)
                    else:
                        # This case is when we're skipping the (first)
                        # separator between fields, so we just set our flag
                        # and continue on.
                        found_sep = True
                else:
                    # Emit a field-start event, and go to that state. Also,
                    # reset the "found_sep" flag, for the next time we get to
                    # this state.
                    self.callback("field_start")
                    # Step back one byte: the increment at the bottom of the
                    # loop then lands on this same byte again, now in the
                    # FIELD_NAME state.
                    i -= 1
                    state = QuerystringState.FIELD_NAME
                    found_sep = False

            elif state == QuerystringState.FIELD_NAME:
                # Try and find a separator - we ensure that, if we do, we only
                # look for the equal sign before it.
                sep_pos = data.find(b"&", i, length)

                # See if we can find an equals sign in the remaining data. If
                # so, we can immediately emit the field name and jump to the
                # data state.
                if sep_pos != -1:
                    equals_pos = data.find(b"=", i, sep_pos)
                else:
                    equals_pos = data.find(b"=", i, length)

                if equals_pos != -1:
                    # Emit this name.
                    self.callback("field_name", data, i, equals_pos)

                    # Jump i to this position. Note that it will then have 1
                    # added to it below, which means the next iteration of this
                    # loop will inspect the character after the equals sign.
                    i = equals_pos
                    state = QuerystringState.FIELD_DATA
                else:
                    # No equals sign found.
                    if not strict_parsing:
                        # See also comments in the QuerystringState.FIELD_DATA case below.
                        # If we found the separator, we emit the name and just
                        # end - there's no data callback at all (not even with
                        # a blank value).
                        if sep_pos != -1:
                            self.callback("field_name", data, i, sep_pos)
                            self.callback("field_end")

                            # One before the separator, so that the increment
                            # below makes BEFORE_FIELD see the separator itself.
                            i = sep_pos - 1
                            state = QuerystringState.BEFORE_FIELD
                        else:
                            # Otherwise, no separator in this block, so the
                            # rest of this chunk must be a name.
                            self.callback("field_name", data, i, length)
                            i = length

                    else:
                        # We're parsing strictly. If we find a separator,
                        # this is an error - we require an equals sign.
                        if sep_pos != -1:
                            raise QuerystringParseError(
                                "When strict_parsing is True, we require an "
                                "equals sign in all field chunks. Did not "
                                "find one in the chunk that starts at %d" % (i,),
                                offset=i,
                            )

                        # No separator in the rest of this chunk, so it's just
                        # a field name.
                        self.callback("field_name", data, i, length)
                        i = length

            elif state == QuerystringState.FIELD_DATA:
                # Try finding an ampersand after this position.
                sep_pos = data.find(b"&", i, length)

                # If we found it, callback this bit as data and then go back
                # to expecting to find a field.
                if sep_pos != -1:
                    self.callback("field_data", data, i, sep_pos)
                    self.callback("field_end")

                    # Note that we go to the separator, which brings us to the
                    # "before field" state. This allows us to properly emit
                    # "field_start" events only when we actually have data for
                    # a field of some sort.
                    i = sep_pos - 1
                    state = QuerystringState.BEFORE_FIELD

                # Otherwise, emit the rest as data and finish.
                else:
                    self.callback("field_data", data, i, length)
                    i = length

            else:  # pragma: no cover (error case)
                msg = "Reached an unknown state %d at %d" % (state, i)
                self.logger.warning(msg)
                raise QuerystringParseError(msg, offset=i)

            i += 1

        # Persist the state so that the next chunk continues where this one ended.
        self.state = state
        self._found_sep = found_sep
        return length
962
def finalize(self) -> None:
    """Signal that no more data will be written to this parser.

    If parsing stopped while a field was still open (its name or its data
    was being read), the ``field_end`` callback is fired first to close it.
    The ``end`` callback is then fired unconditionally.
    """
    current = self.state
    field_is_open = current == QuerystringState.FIELD_NAME or current == QuerystringState.FIELD_DATA
    if field_is_open:
        # Close the field that was still being read when the input ended.
        self.callback("field_end")

    self.callback("end")
972
def __repr__(self) -> str:
    """Return a debugging representation showing the parser's class name and settings."""
    cls_name = self.__class__.__name__
    return f"{cls_name}(strict_parsing={self.strict_parsing!r}, max_size={self.max_size!r})"
977
978
class MultipartParser(BaseParser):
    """This class is a streaming multipart/form-data parser.

    | Callback Name      | Parameters      | Description |
    |--------------------|-----------------|-------------|
    | on_part_begin      | None            | Called when a new part of the multipart message is encountered. |
    | on_part_data       | data, start, end| Called when a portion of a part's data is encountered. |
    | on_part_end        | None            | Called when the end of a part is reached. |
    | on_header_begin    | None            | Called when we've found a new header in a part of a multipart message |
    | on_header_field    | data, start, end| Called each time an additional portion of a header is read (i.e. the part of the header that is before the colon; the "Foo" in "Foo: Bar"). |
    | on_header_value    | data, start, end| Called when we get data for a header. |
    | on_header_end      | None            | Called when the current header is finished - i.e. we've reached the newline at the end of the header. |
    | on_headers_finished| None            | Called when all headers are finished, and before the part data starts. |
    | on_end             | None            | Called when the parser is finished parsing all data. |

    Args:
        boundary: The multipart boundary. This is required, and must match what is given in the HTTP request - usually in the Content-Type header.
        callbacks: A dictionary of callbacks. See the documentation for [`BaseParser`][python_multipart.BaseParser]. When omitted, each parser gets its own empty dictionary.
        max_size: The maximum size of body to parse. Defaults to infinity - i.e. unbounded.
        max_header_count: The maximum number of headers allowed per part.
        max_header_size: The maximum size of a single header line (excluding the trailing CRLF).
    """  # noqa: E501

    def __init__(
        self,
        boundary: bytes | str,
        callbacks: MultipartCallbacks | None = None,
        max_size: float = float("inf"),
        *,
        max_header_count: int = DEFAULT_MAX_HEADER_COUNT,
        max_header_size: int = DEFAULT_MAX_HEADER_SIZE,
    ) -> None:
        """Set up the parser state, limits and boundary.

        Raises:
            ValueError: If ``max_size`` is not a number or is smaller than 1.
            FormParserError: If the boundary is longer than ``MAX_BOUNDARY_LENGTH``.
        """
        # Initialize parser state.
        super().__init__()
        self.state = MultipartState.START
        self.index = self.flags = 0

        # The mapping is stored on the instance as-is, so the "no callbacks"
        # case must get a fresh dict: a shared default object would let an
        # in-place change made through one parser show up in every other
        # parser that was also built without explicit callbacks.
        self.callbacks = {} if callbacks is None else callbacks

        if not isinstance(max_size, Number) or max_size < 1:
            # An f-string (rather than "%r" % max_size) keeps this a ValueError
            # even when the offending value is a tuple.
            raise ValueError(f"max_size must be a positive number, not {max_size!r}")
        self.max_size = max_size
        self._current_size = 0

        self.max_header_count = max_header_count
        self._current_header_count = 0

        self.max_header_size = max_header_size
        self._current_header_size = 0

        # Setup marks. These are used to track the state of data received.
        self.marks: dict[str, int] = {}

        # Save our boundary.
        if isinstance(boundary, str):  # pragma: no cover
            boundary = boundary.encode("latin-1")
        if len(boundary) > MAX_BOUNDARY_LENGTH:
            raise FormParserError(f"Boundary length {len(boundary)} exceeds maximum of {MAX_BOUNDARY_LENGTH}")
        # The stored boundary includes the CRLF and the two hyphens that
        # precede the boundary token, so it can be searched for in part data
        # directly.
        self.boundary = b"\r\n--" + boundary

    def write(self, data: bytes) -> int:
        """Write some data to the parser, which will perform size verification,
        and then parse the data into the appropriate location (e.g. header,
        data, etc.), and pass this on to the underlying callback. If an error
        is encountered, a MultipartParseError will be raised. The "offset"
        attribute on the raised exception will be set to the offset of the byte
        in the input chunk that caused the error.

        If accepting the whole chunk would exceed ``max_size``, only the bytes
        that still fit are parsed and a warning is logged.

        Args:
            data: The data to write to the parser.

        Returns:
            The number of bytes written.
        """
        # Handle sizing.
        data_len = len(data)
        if (self._current_size + data_len) > self.max_size:
            # We truncate the length of data that we are to process.
            new_size = int(self.max_size - self._current_size)
            self.logger.warning(
                "Current size is %d (max %d), so truncating data length from %d to %d",
                self._current_size,
                self.max_size,
                data_len,
                new_size,
            )
            data_len = new_size

        written = 0
        try:
            written = self._internal_write(data, data_len)
        finally:
            # Also runs when parsing raised; in that case nothing is counted.
            self._current_size += written

        return written

    def _internal_write(self, data: bytes, length: int) -> int:
        """Run the multipart state machine over the first ``length`` bytes of ``data``.

        The parser state (``state``, ``index``, ``flags`` and the per-part header
        counters) is copied into locals, advanced while scanning, and written
        back to the instance only when the whole chunk was processed without an
        error. Data callbacks are driven by "marks" stored in ``self.marks``.

        Args:
            data: The chunk to parse.
            length: How many leading bytes of ``data`` may be consumed.

        Returns:
            ``length``, i.e. the number of bytes that were processed.

        Raises:
            MultipartParseError: On malformed input or when a header limit is
                exceeded; ``offset`` is the position in ``data`` of the
                offending byte.
        """
        # Cache the boundary in locals.
        boundary = self.boundary
        boundary_length = len(boundary)

        # Get our state, flags and index. These are persisted between calls to
        # this function.
        state = self.state
        index = self.index
        flags = self.flags
        current_header_count = self._current_header_count
        current_header_size = self._current_header_size

        # Our index defaults to 0.
        i = 0

        # Account for `amount` more bytes of the current header line and fail
        # once the configured per-header limit is exceeded.
        def advance_header_size(amount: int = 1) -> None:
            nonlocal current_header_size
            current_header_size += amount
            if current_header_size > self.max_header_size:
                raise MultipartParseError("Maximum header size exceeded", offset=i)

        # Set a mark at the current position.
        def set_mark(name: str) -> None:
            self.marks[name] = i

        # Remove a mark.
        def delete_mark(name: str) -> None:
            self.marks.pop(name, None)

        # Helper function that makes calling a callback with data easier. It
        # emits the data between the mark called `name` and `end_i`. With
        # 'remaining' set, the mark is kept for the next chunk instead of being
        # deleted: it is re-set to `end_i - length`, which is 0 when everything
        # up to the end of the buffer was emitted and negative when the tail of
        # the buffer is being held back as a possible partial boundary. This is
        # used at the end of the function to call our callbacks with any
        # remaining data in this chunk.
        def data_callback(name: CallbackName, end_i: int, remaining: bool = False) -> None:
            marked_index = self.marks.get(name)
            if marked_index is None:
                return

            # Otherwise, we call it from the mark to the current byte we're
            # processing.
            if end_i <= marked_index:
                # There is no additional data to send.
                pass
            elif marked_index >= 0:
                # We are emitting data from the local buffer.
                self.callback(name, data, marked_index, end_i)
            else:
                # Some of the data comes from a partial boundary match.
                # and requires look-behind.
                # We need to use self.flags (and not flags) because we care about
                # the state when we entered the loop.
                lookbehind_len = -marked_index
                if lookbehind_len <= boundary_length:
                    self.callback(name, boundary, 0, lookbehind_len)
                elif self.flags & FLAG_PART_BOUNDARY:
                    lookback = boundary + b"\r\n"
                    self.callback(name, lookback, 0, lookbehind_len)
                elif self.flags & FLAG_LAST_BOUNDARY:
                    lookback = boundary + b"--\r\n"
                    self.callback(name, lookback, 0, lookbehind_len)
                else:  # pragma: no cover (error case)
                    self.logger.warning("Look-back buffer error")

                if end_i > 0:
                    self.callback(name, data, 0, end_i)
            # If we're getting remaining data, we have got all the data we
            # can be certain is not a boundary, leaving only a partial boundary match.
            if remaining:
                self.marks[name] = end_i - length
            else:
                self.marks.pop(name, None)

        # For each byte...
        while i < length:
            c = data[i]

            if state == MultipartState.START:
                # Skip leading newlines
                if c == CR or c == LF:
                    # Bounded by `length` like every other scan in this
                    # function, so bytes beyond the size cap are not examined.
                    i = data.find(b"-", i, length)
                    if i == -1:
                        # No boundary candidate in this chunk, so ignore the content after the leading CR/LF.
                        i = length
                        break
                    continue

                # index is used as an index into our boundary. Set to 0.
                index = 0

                # Move to the next state, but decrement i so that we re-process
                # this character.
                state = MultipartState.START_BOUNDARY
                i -= 1

            elif state == MultipartState.START_BOUNDARY:
                # `index` counts bytes matched against the boundary without its
                # leading CRLF (hence the `index + 2` below). Once the whole
                # token matched, the next two bytes must be CRLF - or a hyphen,
                # which may start the closing "--" of an empty message.
                if index == boundary_length - 2:
                    if c == HYPHEN:
                        # Potential empty message.
                        state = MultipartState.END_BOUNDARY
                    elif c != CR:
                        # Error!
                        msg = "Did not find CR at end of boundary (%d)" % (i,)
                        self.logger.warning(msg)
                        raise MultipartParseError(msg, offset=i)

                    index += 1

                elif index == boundary_length - 1:
                    if c != LF:
                        msg = "Did not find LF at end of boundary (%d)" % (i,)
                        self.logger.warning(msg)
                        raise MultipartParseError(msg, offset=i)

                    # The index is now used for indexing into our boundary.
                    index = 0

                    # Callback for the start of a part.
                    self.callback("part_begin")
                    current_header_count = 0
                    current_header_size = 0

                    # Move to the next character and state.
                    state = MultipartState.HEADER_FIELD_START

                else:
                    # Check to ensure our boundary matches
                    if c != boundary[index + 2]:
                        msg = "Expected boundary character %r, got %r at index %d" % (boundary[index + 2], c, index + 2)
                        self.logger.warning(msg)
                        raise MultipartParseError(msg, offset=i)

                    # Increment index into boundary and continue.
                    index += 1

            elif state == MultipartState.HEADER_FIELD_START:
                # Mark the start of a header field here, reset the index, and
                # continue parsing our header field.
                index = 0

                if c != CR:
                    current_header_count += 1
                    if current_header_count > self.max_header_count:
                        raise MultipartParseError("Maximum header count exceeded", offset=i)
                    current_header_size = 0

                # Set a mark of our header field.
                set_mark("header_field")

                # Notify that we're starting a header if the next character is
                # not a CR; a CR at the beginning of the header will cause us
                # to stop parsing headers in the MultipartState.HEADER_FIELD state,
                # below.
                if c != CR:
                    self.callback("header_begin")

                # Move to parsing header fields.
                state = MultipartState.HEADER_FIELD
                i -= 1

            elif state == MultipartState.HEADER_FIELD:
                # If we've reached a CR at the beginning of a header, it means
                # that we've reached the second of 2 newlines, and so there are
                # no more headers to parse.
                if c == CR and index == 0:
                    delete_mark("header_field")
                    state = MultipartState.HEADERS_ALMOST_DONE
                    i += 1
                    continue

                # The field name runs until the colon; jump straight to it and
                # validate the whole span at once instead of byte by byte.
                colon = data.find(b":", i, length)
                end = colon if colon != -1 else length

                # Enforce the size limit before slicing and validating, so an oversized header
                # name fails fast instead of copying and scanning a potentially huge span.
                advance_header_size(end - i if colon == -1 else end - i + 1)

                field = data[i:end]
                if field.translate(None, TOKEN_CHARS):
                    # Something other than a token character survived the
                    # deletion; locate the first such byte for the error.
                    bad = next(b for b in field if b not in TOKEN_CHARS_SET)
                    bad_i = i + field.index(bad)
                    msg = "Found invalid character %r in header at %d" % (bad, bad_i)
                    self.logger.warning(msg)
                    raise MultipartParseError(msg, offset=bad_i)

                index += end - i
                if colon == -1:
                    # Field name continues into the next chunk.
                    i = length
                else:
                    # A 0-length header is an error.
                    if index == 0:
                        msg = "Found 0-length header at %d" % (i,)
                        self.logger.warning(msg)
                        raise MultipartParseError(msg, offset=i)

                    # Call our callback with the header field.
                    i = colon
                    data_callback("header_field", i)

                    # Move to parsing the header value.
                    state = MultipartState.HEADER_VALUE_START

            elif state == MultipartState.HEADER_VALUE_START:
                # Skip leading spaces.
                if c == SPACE:
                    advance_header_size()
                    i += 1
                    continue

                # Mark the start of the header value.
                set_mark("header_value")

                # Move to the header-value state, reprocessing this character.
                state = MultipartState.HEADER_VALUE
                i -= 1

            elif state == MultipartState.HEADER_VALUE:
                # The value runs until the terminating CR; jump straight to it
                # instead of inspecting every byte.
                cr = data.find(b"\r", i, length)
                end = cr if cr != -1 else length
                advance_header_size(end - i)
                if cr != -1:
                    i = cr
                    data_callback("header_value", i)
                    self.callback("header_end")
                    current_header_size = 0
                    state = MultipartState.HEADER_VALUE_ALMOST_DONE
                else:
                    # Value continues into the next chunk.
                    i = length

            elif state == MultipartState.HEADER_VALUE_ALMOST_DONE:
                # The last character should be a LF. If not, it's an error.
                if c != LF:
                    msg = f"Did not find LF character at end of header (found {c!r})"
                    self.logger.warning(msg)
                    raise MultipartParseError(msg, offset=i)

                # Move back to the start of another header. Note that if that
                # state detects ANOTHER newline, it'll trigger the end of our
                # headers.
                state = MultipartState.HEADER_FIELD_START

            elif state == MultipartState.HEADERS_ALMOST_DONE:
                # We're almost done our headers. This is reached when we parse
                # a CR at the beginning of a header, so our next character
                # should be a LF, or it's an error.
                if c != LF:
                    msg = f"Did not find LF at end of headers (found {c!r})"
                    self.logger.warning(msg)
                    raise MultipartParseError(msg, offset=i)

                self.callback("headers_finished")
                state = MultipartState.PART_DATA_START

            elif state == MultipartState.PART_DATA_START:
                # Mark the start of our part data.
                set_mark("part_data")

                # Start processing part data, including this character.
                state = MultipartState.PART_DATA
                i -= 1

            elif state == MultipartState.PART_DATA:
                # We're processing our part data right now. During this, we
                # need to efficiently search for our boundary, since any data
                # on any number of lines can be a part of the current data.

                # Save the current value of our index. We use this in case we
                # find part of a boundary, but it doesn't match fully.
                prev_index = index

                # If our index is 0, we're starting a new part, so start our
                # search.
                if index == 0:
                    # The most common case is likely to be that the whole
                    # boundary is present in the buffer.
                    # Calling `find` is much faster than iterating here.
                    i0 = data.find(boundary, i, length)
                    if i0 >= 0:
                        # We matched the whole boundary string. Position on its
                        # last byte so the generic matching code below
                        # completes the match.
                        index = boundary_length - 1
                        i = i0 + boundary_length - 1
                        c = data[i]
                    else:
                        # No whole boundary, but the tail may hold a partial one
                        # that completes in the next chunk. Boundary starts with
                        # CR, which an RFC boundary contains nowhere else, so the
                        # last CR in the tail is the only candidate prefix start.
                        k = data.rfind(boundary[:1], max(i, length - boundary_length + 1), length)
                        if k != -1 and boundary.startswith(data[k:length]):
                            index = length - k
                        # Carry the partial via index; the end-of-chunk flush
                        # emits the data before it and re-marks the lookbehind.
                        i = length
                        continue

                # Now, we have a couple of cases here. If our index is before
                # the end of the boundary...
                if index < boundary_length:
                    # If the character matches...
                    if boundary[index] == c:
                        # The current character matches, so continue!
                        index += 1
                    else:
                        index = 0

                # Our index is equal to the length of our boundary!
                elif index == boundary_length:
                    # First we increment it.
                    index += 1

                    # Now, if we've reached a newline, we need to set this as
                    # the potential end of our boundary.
                    if c == CR:
                        flags |= FLAG_PART_BOUNDARY

                    # Otherwise, if this is a hyphen, we might be at the last
                    # of all boundaries.
                    elif c == HYPHEN:
                        flags |= FLAG_LAST_BOUNDARY

                    # Otherwise, we reset our index, since this isn't either a
                    # newline or a hyphen.
                    else:
                        index = 0

                # Our index is right after the part boundary, which should be
                # a LF.
                elif index == boundary_length + 1:
                    # If we're at a part boundary (i.e. we've seen a CR
                    # character already)...
                    if flags & FLAG_PART_BOUNDARY:
                        # We need a LF character next.
                        if c == LF:
                            # Unset the part boundary flag.
                            flags &= ~FLAG_PART_BOUNDARY

                            # We have identified a boundary, callback for any data before it.
                            data_callback("part_data", i - index)
                            # Callback indicating that we've reached the end of
                            # a part, and are starting a new one.
                            self.callback("part_end")
                            self.callback("part_begin")
                            current_header_count = 0
                            current_header_size = 0

                            # Move to parsing new headers.
                            index = 0
                            state = MultipartState.HEADER_FIELD_START
                            i += 1
                            continue

                        # We didn't find an LF character, so no match. Reset
                        # our index and clear our flag.
                        index = 0
                        flags &= ~FLAG_PART_BOUNDARY

                    # Otherwise, if we're at the last boundary (i.e. we've
                    # seen a hyphen already)...
                    elif flags & FLAG_LAST_BOUNDARY:
                        # We need a second hyphen here.
                        if c == HYPHEN:
                            # We have identified a boundary, callback for any data before it.
                            data_callback("part_data", i - index)
                            # Callback to end the current part, and then the
                            # message.
                            self.callback("part_end")
                            self.callback("end")
                            state = MultipartState.END
                        else:
                            # No match, so reset index.
                            index = 0

                # Otherwise, our index is 0. If the previous index is not, it
                # means we reset something, and we need to take the data we
                # thought was part of our boundary and send it along as actual
                # data.
                if index == 0 and prev_index > 0:
                    # Overwrite our previous index.
                    prev_index = 0

                    # Re-consider the current character, since this could be
                    # the start of the boundary itself.
                    i -= 1

            elif state == MultipartState.END_BOUNDARY:
                if index == boundary_length - 1:
                    if c != HYPHEN:
                        msg = "Did not find - at end of boundary (%d)" % (i,)
                        self.logger.warning(msg)
                        raise MultipartParseError(msg, offset=i)
                    index += 1
                    self.callback("end")
                    state = MultipartState.END

            elif state == MultipartState.END:
                # Silently discard any epilogue data (RFC 2046 section 5.1.1 allows a CRLF and optional
                # epilogue after the closing boundary). Django and Werkzeug do the same.
                i = length
                break

            else:  # pragma: no cover (error case)
                # We got into a strange state somehow! Just stop processing.
                msg = "Reached an unknown state %d at %d" % (state, i)
                self.logger.warning(msg)
                raise MultipartParseError(msg, offset=i)

            # Move to the next byte.
            i += 1

        # We call our callbacks with any remaining data. Note that we pass
        # the 'remaining' flag, which keeps the mark instead of deleting it,
        # if it's found. This is because, if the mark is found at this point,
        # we assume that there's data for one of these things that has been
        # parsed, but not yet emitted. And, as such, it implies that we haven't
        # yet reached the end of this 'thing'. The mark is re-set relative to
        # the start of the next buffer: 0 for the header marks, and minus the
        # length of the held-back partial boundary for part data (those bytes
        # are excluded here via `length - index`).
        data_callback("header_field", length, True)
        data_callback("header_value", length, True)
        data_callback("part_data", length - index, True)

        # Persist the loop's local state back onto the instance.
        self.state = state
        self.index = index
        self.flags = flags
        self._current_header_count = current_header_count
        self._current_header_size = current_header_size

        # Return our data length to indicate no errors, and that we processed
        # all of it.
        return length

    def finalize(self) -> None:
        """Finalize this parser, which signals that we are finished parsing.

        Note: It does not currently, but in the future, it will verify that we
        are in the final state of the parser (i.e. the end of the multipart
        message is well-formed), and, if not, throw an error.
        """
        # TODO: verify that we're in the state MultipartState.END, otherwise throw an
        # error or otherwise state that we're not finished parsing.
        pass

    def __repr__(self) -> str:
        """Return a debugging representation showing the stored boundary (including its CRLF-- prefix)."""
        return f"{self.__class__.__name__}(boundary={self.boundary!r})"
1528
1529
1530class FormParser:
1531 """This class is the all-in-one form parser. Given all the information
1532 necessary to parse a form, it will instantiate the correct parser, create
1533 the proper :class:`Field` and :class:`File` classes to store the data that
1534 is parsed, and call the two given callbacks with each field and file as
1535 they become available.
1536
1537 Args:
1538 content_type: The Content-Type of the incoming request. This is used to select the appropriate parser.
1539 on_field: The callback to call when a field has been parsed and is ready for usage. See above for parameters.
1540 on_file: The callback to call when a file has been parsed and is ready for usage. See above for parameters.
1541 on_end: An optional callback to call when all fields and files in a request has been parsed. Can be None.
1542 boundary: If the request is a multipart/form-data request, this should be the boundary of the request, as given
1543 in the Content-Type header, as a bytestring.
1544 file_name: If the request is of type application/octet-stream, then the body of the request will not contain any
1545 information about the uploaded file. In such cases, you can provide the file name of the uploaded file
1546 manually.
1547 config: Configuration to use for this FormParser. The default values are taken from the DEFAULT_CONFIG value,
1548 and then any keys present in this dictionary will overwrite the default values.
1549 """
1550
1551 #: This is the default configuration for our form parser.
1552 #: Note: all file sizes should be in bytes.
1553 DEFAULT_CONFIG: FormParserConfig = {
1554 "MAX_BODY_SIZE": float("inf"),
1555 "MAX_HEADER_COUNT": DEFAULT_MAX_HEADER_COUNT,
1556 "MAX_HEADER_SIZE": DEFAULT_MAX_HEADER_SIZE,
1557 "MAX_MEMORY_FILE_SIZE": 1 * 1024 * 1024,
1558 "UPLOAD_DIR": None,
1559 "UPLOAD_DELETE_TMP": True,
1560 "UPLOAD_KEEP_FILENAME": False,
1561 "UPLOAD_KEEP_EXTENSIONS": False,
1562 # Error on invalid Content-Transfer-Encoding?
1563 "UPLOAD_ERROR_ON_BAD_CTE": False,
1564 }
1565
1566 def __init__(
1567 self,
1568 content_type: str,
1569 on_field: Callable[[Field], None] | None,
1570 on_file: Callable[[File], None] | None,
1571 on_end: Callable[[], None] | None = None,
1572 boundary: bytes | str | None = None,
1573 file_name: bytes | None = None,
1574 config: dict[Any, Any] = {},
1575 ) -> None:
1576 self.logger = logging.getLogger(__name__)
1577
1578 # Save variables.
1579 self.content_type = content_type
1580 self.boundary = boundary
1581 self.bytes_received = 0
1582 self.parser = None
1583
1584 # Save callbacks.
1585 self.on_field = on_field
1586 self.on_file = on_file
1587 self.on_end = on_end
1588
1589 # Set configuration options.
1590 self.config: FormParserConfig = self.DEFAULT_CONFIG.copy()
1591 self.config.update(config) # type: ignore[typeddict-item]
1592
1593 parser: OctetStreamParser | MultipartParser | QuerystringParser | None = None
1594
1595 # Depending on the Content-Type, we instantiate the correct parser.
1596 if content_type == "application/octet-stream":
1597 file: File | None = None
1598
1599 def on_start() -> None:
1600 nonlocal file
1601 file = File(file_name, None, config=self.config)
1602
1603 def on_data(data: bytes, start: int, end: int) -> None:
1604 nonlocal file
1605 assert file is not None
1606 file.write(data[start:end])
1607
1608 def _on_end() -> None:
1609 nonlocal file
1610 assert file is not None
1611 # Finalize the file itself.
1612 file.finalize()
1613
1614 # Call our callback.
1615 if on_file:
1616 on_file(file)
1617
1618 # Call the on-end callback.
1619 if self.on_end is not None:
1620 self.on_end()
1621
1622 # Instantiate an octet-stream parser
1623 parser = OctetStreamParser(
1624 callbacks={"on_start": on_start, "on_data": on_data, "on_end": _on_end},
1625 max_size=self.config["MAX_BODY_SIZE"],
1626 )
1627
1628 elif content_type == "application/x-www-form-urlencoded" or content_type == "application/x-url-encoded":
1629 name_buffer: list[bytes] = []
1630
1631 f: Field | None = None
1632
1633 def on_field_start() -> None:
1634 pass
1635
1636 def on_field_name(data: bytes, start: int, end: int) -> None:
1637 name_buffer.append(data[start:end])
1638
1639 def on_field_data(data: bytes, start: int, end: int) -> None:
1640 nonlocal f
1641 if f is None:
1642 f = Field(b"".join(name_buffer))
1643 del name_buffer[:]
1644 f.write(data[start:end])
1645
1646 def on_field_end() -> None:
1647 nonlocal f
1648 # Finalize and call callback.
1649 if f is None:
1650 # If we get here, it's because there was no field data.
1651 # We create a field, set it to None, and then continue.
1652 f = Field(b"".join(name_buffer))
1653 del name_buffer[:]
1654 f.set_none()
1655
1656 f.finalize()
1657 if on_field:
1658 on_field(f)
1659 f = None
1660
1661 def _on_end() -> None:
1662 if self.on_end is not None:
1663 self.on_end()
1664
1665 # Instantiate parser.
1666 parser = QuerystringParser(
1667 callbacks={
1668 "on_field_start": on_field_start,
1669 "on_field_name": on_field_name,
1670 "on_field_data": on_field_data,
1671 "on_field_end": on_field_end,
1672 "on_end": _on_end,
1673 },
1674 max_size=self.config["MAX_BODY_SIZE"],
1675 )
1676
1677 elif content_type == "multipart/form-data":
1678 if boundary is None:
1679 self.logger.error("No boundary given")
1680 raise FormParserError("No boundary given")
1681
1682 header_name: list[bytes] = []
1683 header_value: list[bytes] = []
1684 headers: dict[bytes, bytes] = {}
1685
1686 f_multi: File | Field | None = None
1687 writer: File | Field | Base64Decoder | QuotedPrintableDecoder | None = None
1688 is_file = False
1689
1690 def on_part_begin() -> None:
1691 # Reset headers in case this isn't the first part.
1692 nonlocal headers
1693 headers = {}
1694
1695 def on_part_data(data: bytes, start: int, end: int) -> None:
1696 nonlocal writer
1697 assert writer is not None
1698 writer.write(data[start:end])
1699 # TODO: check for error here.
1700
1701 def on_part_end() -> None:
1702 nonlocal f_multi, is_file
1703 assert f_multi is not None
1704 f_multi.finalize()
1705 if is_file:
1706 if on_file:
1707 assert isinstance(f_multi, File)
1708 on_file(f_multi)
1709 else:
1710 if on_field:
1711 assert isinstance(f_multi, Field)
1712 on_field(f_multi)
1713
1714 def on_header_field(data: bytes, start: int, end: int) -> None:
1715 header_name.append(data[start:end])
1716
1717 def on_header_value(data: bytes, start: int, end: int) -> None:
1718 header_value.append(data[start:end])
1719
1720 def on_header_end() -> None:
1721 headers[b"".join(header_name).lower()] = b"".join(header_value)
1722 del header_name[:]
1723 del header_value[:]
1724
1725 def on_headers_finished() -> None:
1726 nonlocal is_file, f_multi, writer
1727 # Reset the 'is file' flag.
1728 is_file = False
1729
1730 # Parse the content-disposition header.
1731 content_disp = headers.get(b"content-disposition")
1732 disp, options = parse_options_header(content_disp)
1733
1734 # Get the field and filename.
1735 field_name = options.get(b"name")
1736 file_name = options.get(b"filename")
1737 # RFC 7578 §4.2: each part MUST have a Content-Disposition header with a "name" parameter.
1738 if field_name is None:
1739 raise FormParserError(f'Field name not found in Content-Disposition: "{content_disp!r}"')
1740
1741 # Create the proper class.
1742 content_type_b = headers.get(b"content-type")
1743 content_type = content_type_b.decode("latin-1") if content_type_b is not None else None
1744 if file_name is None:
1745 f_multi = Field(field_name, content_type=content_type)
1746 else:
1747 f_multi = File(file_name, field_name, config=self.config, content_type=content_type)
1748 is_file = True
1749
1750 # Parse the given Content-Transfer-Encoding to determine what
1751 # we need to do with the incoming data.
1752 # TODO: check that we properly handle 8bit / 7bit encoding.
1753 # RFC 2045 section 6.1: Content-Transfer-Encoding values are case-insensitive.
1754 # https://www.rfc-editor.org/rfc/rfc2045#section-6.1
1755 transfer_encoding = headers.get(b"content-transfer-encoding", b"7bit").lower()
1756
1757 if transfer_encoding in (b"binary", b"8bit", b"7bit"):
1758 writer = f_multi
1759
1760 elif transfer_encoding == b"base64":
1761 writer = Base64Decoder(f_multi)
1762
1763 elif transfer_encoding == b"quoted-printable":
1764 writer = QuotedPrintableDecoder(f_multi)
1765
1766 else:
1767 self.logger.warning("Unknown Content-Transfer-Encoding: %r", transfer_encoding)
1768 if self.config["UPLOAD_ERROR_ON_BAD_CTE"]:
1769 raise FormParserError(f'Unknown Content-Transfer-Encoding "{transfer_encoding!r}"')
1770 else:
1771 # If we aren't erroring, then we just treat this as an
1772 # unencoded Content-Transfer-Encoding.
1773 writer = f_multi
1774
1775 def _on_end() -> None:
1776 nonlocal writer
1777 if writer is not None:
1778 writer.finalize()
1779 if self.on_end is not None:
1780 self.on_end()
1781
1782 # Instantiate a multipart parser.
1783 parser = MultipartParser(
1784 boundary,
1785 callbacks={
1786 "on_part_begin": on_part_begin,
1787 "on_part_data": on_part_data,
1788 "on_part_end": on_part_end,
1789 "on_header_field": on_header_field,
1790 "on_header_value": on_header_value,
1791 "on_header_end": on_header_end,
1792 "on_headers_finished": on_headers_finished,
1793 "on_end": _on_end,
1794 },
1795 max_size=self.config["MAX_BODY_SIZE"],
1796 max_header_count=self.config["MAX_HEADER_COUNT"],
1797 max_header_size=self.config["MAX_HEADER_SIZE"],
1798 )
1799
1800 else:
1801 self.logger.warning("Unknown Content-Type: %r", content_type)
1802 raise FormParserError(f"Unknown Content-Type: {content_type}")
1803
1804 self.parser = parser
1805
def write(self, data: bytes) -> int:
    """Feed a chunk of body data to the underlying parser.

    Args:
        data: The bytes to hand over to the parser.

    Returns:
        The value returned by the underlying parser's ``write`` (the
        number of bytes processed).
    """
    # The running total is bumped before delegating, so it counts every
    # byte offered to this method.
    self.bytes_received += len(data)

    # TODO: check the parser's return value for errors?
    active_parser = self.parser
    assert active_parser is not None
    return active_parser.write(data)
1820
def finalize(self) -> None:
    """Finalize the underlying parser, if one is set and it supports it."""
    underlying = self.parser
    if underlying is None or not hasattr(underlying, "finalize"):
        # No parser, or a parser without a finalize() method: nothing to do.
        return
    underlying.finalize()
1825
def close(self) -> None:
    """Close the underlying parser, if one is set and it supports it."""
    underlying = self.parser
    if underlying is None or not hasattr(underlying, "close"):
        # No parser, or a parser without a close() method: nothing to do.
        return
    underlying.close()
1830
def __repr__(self) -> str:
    """Return a debug string naming the class, its content type and its parser."""
    cls_name = self.__class__.__name__
    details = f"content_type={self.content_type!r}, parser={self.parser!r}"
    return f"{cls_name}({details})"
1833
1834
def create_form_parser(
    headers: dict[str, bytes],
    on_field: Callable[[Field], None] | None,
    on_file: Callable[[File], None] | None,
    config: dict[Any, Any] | None = None,
) -> FormParser:
    """Build a FormParser from a set of HTTP request headers.

    The Content-Type header is read from ``headers`` and split into the bare
    content type and its parameters; the ``boundary`` parameter (if any) and
    the given callbacks are then passed to a new FormParser.

    Args:
        headers: A dictionary-like object of HTTP headers. The only required header is Content-Type.
        on_field: Callback to call with each parsed field.
        on_file: Callback to call with each parsed file.
        config: Configuration variables to pass to the FormParser. When omitted (or None), an
            empty configuration dictionary is passed.

    Returns:
        The newly created FormParser.

    Raises:
        ValueError: If ``headers`` has no Content-Type entry.
    """
    content_type: str | bytes | None = headers.get("Content-Type")
    if content_type is None:
        logging.getLogger(__name__).warning("No Content-Type header given")
        raise ValueError("No Content-Type header given!")

    # Boundaries are optional (the FormParser will raise if one is needed
    # but not given).
    content_type, params = parse_options_header(content_type)
    boundary = params.get(b"boundary")

    # We need content_type to be a string, not a bytes object.
    content_type = content_type.decode("latin-1")

    # A fresh dict per call replaces the former shared mutable default
    # (``config={}``), so no state can leak between calls; a dict supplied
    # by the caller is forwarded unchanged.
    effective_config = {} if config is None else config

    return FormParser(content_type, on_field, on_file, boundary=boundary, config=effective_config)
1871
1872
def parse_form(
    headers: dict[str, bytes],
    input_stream: SupportsRead,
    on_field: Callable[[Field], None] | None,
    on_file: Callable[[File], None] | None,
    chunk_size: int = 1048576,
) -> None:
    """Parse a whole request body in one call.

    A form parser is created from ``headers``, the body is read from
    ``input_stream`` in pieces of at most ``chunk_size`` bytes and written to
    the parser, and the parser is finalized once reading stops.

    Args:
        headers: A dictionary-like object of HTTP headers. The only required header is Content-Type.
        input_stream: A file-like object that represents the request body. The read() method must return bytestrings.
        on_field: Callback to call with each parsed field.
        on_file: Callback to call with each parsed file.
        chunk_size: The maximum size to read from the input stream and write to the parser at one time.
            Defaults to 1 MiB.

    Raises:
        ValueError: If ``chunk_size`` is smaller than 1, or if the Content-Length header is negative.
    """
    if chunk_size < 1:
        raise ValueError(f"chunk_size must be a positive number, not {chunk_size!r}")

    form_parser = create_form_parser(headers, on_field, on_file)

    # Work out how many bytes may be read in total: the Content-Length when
    # one is given, otherwise no limit at all.
    declared_length = headers.get("Content-Length")
    limit: int | float
    if declared_length is None:
        limit = float("inf")
    else:
        limit = int(declared_length)
        if limit < 0:
            raise ValueError("Content-Length must be non-negative")

    consumed = 0
    while True:
        # Never ask for more than one chunk, nor for more than what is left
        # of the declared length.
        wanted = int(min(limit - consumed, chunk_size))
        chunk = input_stream.read(wanted)

        form_parser.write(chunk)
        consumed += len(chunk)

        # Keep going only while reads come back full-sized and the declared
        # length has not been reached yet.
        if len(chunk) == wanted and consumed != limit:
            continue
        break

    # Signal the parser that no more data is coming.
    form_parser.finalize()