1from __future__ import annotations
2
3import logging
4import os
5import shutil
6import sys
7import tempfile
8from email.message import Message
9from enum import IntEnum
10from io import BufferedRandom, BytesIO
11from numbers import Number
12from typing import TYPE_CHECKING, cast
13
14from .decoders import Base64Decoder, QuotedPrintableDecoder
15from .exceptions import FileError, FormParserError, MultipartParseError, QuerystringParseError
16
if TYPE_CHECKING:
    # Names below exist only for static type checkers (TYPE_CHECKING is False
    # at runtime), so they may be used only in annotations; the module-level
    # `from __future__ import annotations` keeps those annotations unevaluated.
    from collections.abc import Callable
    from typing import Any, Literal, Protocol, TypeAlias, TypedDict

    # Structural type: any object with a `read(n)` method returning bytes.
    class SupportsRead(Protocol):
        def read(self, __n: int) -> bytes: ...

    # Callback table for QuerystringParser. Keys are "on_" + callback name
    # (BaseParser.callback builds the key that way). Data callbacks receive
    # (data, start, end); notification callbacks take no arguments.
    # total=False: every key is optional.
    class QuerystringCallbacks(TypedDict, total=False):
        on_field_start: Callable[[], None]
        on_field_name: Callable[[bytes, int, int], None]
        on_field_data: Callable[[bytes, int, int], None]
        on_field_end: Callable[[], None]
        on_end: Callable[[], None]

    # Callback table for OctetStreamParser (same key/argument conventions).
    class OctetStreamCallbacks(TypedDict, total=False):
        on_start: Callable[[], None]
        on_data: Callable[[bytes, int, int], None]
        on_end: Callable[[], None]

    # Callback table for multipart parsing (same key/argument conventions).
    class MultipartCallbacks(TypedDict, total=False):
        on_part_begin: Callable[[], None]
        on_part_data: Callable[[bytes, int, int], None]
        on_part_end: Callable[[], None]
        on_header_begin: Callable[[], None]
        on_header_field: Callable[[bytes, int, int], None]
        on_header_value: Callable[[bytes, int, int], None]
        on_header_end: Callable[[], None]
        on_headers_finished: Callable[[], None]
        on_end: Callable[[], None]

    # Configuration keys read by File (see the File docstring for meanings).
    class FileConfig(TypedDict, total=False):
        UPLOAD_DIR: str | bytes | None
        UPLOAD_DELETE_TMP: bool
        UPLOAD_KEEP_FILENAME: bool
        UPLOAD_KEEP_EXTENSIONS: bool
        MAX_MEMORY_FILE_SIZE: int

    # File options plus parser-level limits.
    # NOTE(review): unlike FileConfig this subclass does not pass total=False,
    # so for type checkers the four keys declared here are *required* — confirm
    # that is intended for user-supplied partial configs.
    class FormParserConfig(FileConfig):
        UPLOAD_ERROR_ON_BAD_CTE: bool
        MAX_BODY_SIZE: float
        MAX_HEADER_COUNT: int
        MAX_HEADER_SIZE: int

    # Names accepted by BaseParser.callback()/set_callback(); the callbacks
    # dict stores each one under the key "on_" + name.
    CallbackName: TypeAlias = Literal[
        "start",
        "data",
        "end",
        "field_start",
        "field_name",
        "field_data",
        "field_end",
        "part_begin",
        "part_data",
        "part_end",
        "header_begin",
        "header_field",
        "header_value",
        "header_end",
        "headers_finished",
    ]
77
# Unique sentinel object. Field stores it in its value cache to mean "not
# computed yet"; it must be distinct from None because None is a legitimate
# field value (set by Field.set_none()).
_missing = object()
80
81
class QuerystringState(IntEnum):
    """Querystring parser states.

    These are used to keep track of the state of the parser, and are used to determine
    what to do when new data is encountered.
    """

    # Between fields: "&"/";" separators are skipped until a field begins.
    BEFORE_FIELD = 0
    # Inside a field name, which ends at "=" or at a separator.
    FIELD_NAME = 1
    # Inside a field value, which ends at the next separator.
    FIELD_DATA = 2
92
93
class MultipartState(IntEnum):
    """Multipart parser states.

    These are used to keep track of the state of the parser, and are used to determine
    what to do when new data is encountered.
    """

    # NOTE(review): transitions between these states are implemented by the
    # multipart parser, which is not part of this view. The grouping below
    # follows the member names only; confirm against the parser.
    # Opening boundary.
    START = 0
    START_BOUNDARY = 1
    # Per-part header lines (field name, then value), presumably.
    HEADER_FIELD_START = 2
    HEADER_FIELD = 3
    HEADER_VALUE_START = 4
    HEADER_VALUE = 5
    HEADER_VALUE_ALMOST_DONE = 6
    HEADERS_ALMOST_DONE = 7
    # Part body.
    PART_DATA_START = 8
    PART_DATA = 9
    PART_DATA_END = 10
    # Closing boundary and terminal state.
    END_BOUNDARY = 11
    END = 12
114
115
# Flags for the multipart parser.
FLAG_PART_BOUNDARY = 1
FLAG_LAST_BOUNDARY = 2

# Single-byte constants stored as ints: indexing or iterating over a `bytes`
# object yields `int` values, so the parsers compare against these directly.
CR = ord("\r")
LF = ord("\n")
COLON = ord(":")
SPACE = ord(" ")
HYPHEN = ord("-")
AMPERSAND = ord("&")
SEMICOLON = ord(";")
LOWER_A = ord("a")
LOWER_Z = ord("z")
NULL = ord("\x00")

# fmt: off
# Set of byte values (ints) allowed in an HTTP token.
# Per RFC 7230 section 3.2.6 that is every ASCII digit and letter plus the
# punctuation characters !#$%&'*+-.^_`|~
TOKEN_CHARS_SET = frozenset(
    b"0123456789"
    b"ABCDEFGHIJKLMNOPQRSTUVWXYZ"
    b"abcdefghijklmnopqrstuvwxyz"
    b"!#$%&'*+-.^_`|~"
)
# fmt: on

DEFAULT_MAX_HEADER_COUNT = 8
"""Default maximum number of headers allowed per multipart part."""

DEFAULT_MAX_HEADER_SIZE = 4096 + 128
"""Default maximum size of a single multipart header line, including syntax overhead."""

MAX_BOUNDARY_LENGTH = 256
"""Maximum allowed length of a multipart boundary.

[RFC 2046 §5.1.1](https://datatracker.ietf.org/doc/html/rfc2046#section-5.1.1)
recommends boundaries be at most 70 bytes; 256 bytes leaves generous headroom
over what HTTP clients generate.
"""
158
159
def parse_options_header(value: str | bytes | None) -> tuple[bytes, dict[bytes, bytes]]:
    """Parse a Content-Type-style header into ``(value, {parameters})``.

    For example ``'text/plain; charset=utf-8'`` becomes
    ``(b'text/plain', {b'charset': b'utf-8'})``.

    Args:
        value: The raw header value. ``bytes`` are decoded as latin-1 (the WSGI
            convention). ``None`` or an empty value yields ``(b"", {})``.

    Returns:
        A ``(main_value, parameters)`` tuple of latin-1 encoded bytes. The main
        value is always lower-cased and stripped (MIME types are
        case-insensitive), whether or not parameters are present. For a
        ``filename`` parameter that looks like a full Windows path (as sent by
        IE6), only the final path component is kept.
    """
    # Uses email.message.Message to parse the header as described in PEP 594.
    # Ref: https://peps.python.org/pep-0594/#cgi
    if not value:
        return (b"", {})

    # If we are passed bytes, we assume that it conforms to WSGI, encoding in latin-1.
    if isinstance(value, bytes):  # pragma: no cover
        value = value.decode("latin-1")

    # Narrow the type for type checkers: only str can reach this point.
    assert isinstance(value, str), "Value should be a string by now"

    # Fast path: no parameters, so the whole header is the main value.
    if ";" not in value:
        return (value.lower().strip().encode("latin-1"), {})

    message = Message()
    message["content-type"] = value
    # `get_params()` can raise on malformed RFC 2231 headers found via fuzzing:
    # - ValueError on oversized continuation indices (all supported versions).
    # - TypeError on mixed `filename*` + `filename*0*` continuations (Python 3.12 only;
    #   3.13+ silently picks a value).
    # TODO: drop `TypeError` once Python 3.12 reaches EOL (October 2028).
    try:
        params = message.get_params()
    except (TypeError, ValueError):  # pragma: no cover
        return (value.split(";", 1)[0].lower().strip().encode("latin-1"), {})
    # If there were no parameters, this would have already returned above
    assert params, "At least the content type value should be present"
    # Normalize the main value exactly like the two other return paths do, so
    # b"text/plain" is returned for "Text/Plain" with or without parameters.
    ctype = params.pop(0)[0].lower().strip().encode("latin-1")

    options: dict[bytes, bytes] = {}
    for key, param_value in params:
        # If the value returned from get_params() is a 3-tuple, the last
        # element corresponds to the value.
        # See: https://docs.python.org/3/library/email.compat32-message.html
        if isinstance(param_value, tuple):
            param_value = param_value[-1]
        # IE6 sends the full client-side path ("C:\dir\a.txt" or "\\host\...")
        # instead of the bare file name; keep only the last component.
        if key == "filename" and (param_value[1:3] == ":\\" or param_value[:2] == "\\\\"):
            param_value = param_value.split("\\")[-1]
        options[key.encode("latin-1")] = param_value.encode("latin-1")
    return ctype, options
209
210
class Field:
    """A Field object represents a (parsed) form field. It represents a single
    field with a corresponding name and value.

    The name that a :class:`Field` will be instantiated with is the same name
    that would be found in the following HTML::

        <input name="name_goes_here" type="text"/>

    This class defines two methods, :meth:`on_data` and :meth:`on_end`, that
    will be called when data is written to the Field, and when the Field is
    finalized, respectively.

    Args:
        name: The name of the form field.
        content_type: The value of the Content-Type header for this field.
    """

    def __init__(self, name: bytes | None, *, content_type: str | None = None) -> None:
        self._name = name
        # Chunks in the order they were written; joined lazily by `value`.
        self._value: list[bytes] = []
        self._content_type = content_type

        # Cached b"".join(self._value). `_missing` means "not computed yet";
        # None is a legitimate value, set by set_none().
        self._cache = _missing

    @classmethod
    def from_value(cls, name: bytes, value: bytes | None) -> Field:
        """Create an instance of a :class:`Field`, and set the corresponding
        value - either None or an actual value. This method will also
        finalize the Field itself.

        Args:
            name: the name of the form field.
            value: the value of the form field - either a bytestring or None.

        Returns:
            A new instance of a [`Field`][python_multipart.Field].
        """

        f = cls(name)
        if value is None:
            f.set_none()
        else:
            f.write(value)
        f.finalize()
        return f

    def write(self, data: bytes) -> int:
        """Write some data into the form field.

        Args:
            data: The data to write to the field.

        Returns:
            The number of bytes written.
        """
        return self.on_data(data)

    def on_data(self, data: bytes) -> int:
        """This method is a callback that will be called whenever data is
        written to the Field.

        Args:
            data: The data to write to the field.

        Returns:
            The number of bytes written.
        """
        self._value.append(data)
        # New data invalidates any previously joined value.
        self._cache = _missing
        return len(data)

    def on_end(self) -> None:
        """This method is called whenever the Field is finalized."""
        if self._cache is _missing:
            self._cache = b"".join(self._value)

    def finalize(self) -> None:
        """Finalize the form field."""
        self.on_end()

    def close(self) -> None:
        """Close the Field object, freeing the list of written chunks.

        The joined value is cached first, so :attr:`value` keeps working after
        closing. Closing an already-closed Field is a no-op; writing to a
        closed Field raises ``AttributeError``.
        """
        if self._cache is _missing:
            self._cache = b"".join(self._value)

        # A second close() used to fail here with AttributeError.
        try:
            del self._value
        except AttributeError:
            pass

    def set_none(self) -> None:
        """Some fields in a querystring can possibly have a value of None - for
        example, the string "foo&bar=&baz=asdf" will have a field with the
        name "foo" and value None, one with name "bar" and value "", and one
        with name "baz" and value "asdf". Since the write() interface doesn't
        support writing None, this function will set the field value to None.
        """
        self._cache = None

    @property
    def field_name(self) -> bytes | None:
        """This property returns the name of the field."""
        return self._name

    @property
    def value(self) -> bytes | None:
        """This property returns the value of the form field."""
        if self._cache is _missing:
            self._cache = b"".join(self._value)

        assert isinstance(self._cache, bytes) or self._cache is None
        return self._cache

    @property
    def content_type(self) -> str | None:
        """This property returns the content_type value of the field."""
        return self._content_type

    def __eq__(self, other: object) -> bool:
        if isinstance(other, Field):
            return self.field_name == other.field_name and self.value == other.value
        else:
            return NotImplemented

    def __repr__(self) -> str:
        value = self.value
        if value is not None and len(value) > 97:
            # Insert "..." before the closing quote. repr() of bytes may use
            # either ' or " as the quote character, so reuse whichever it chose
            # instead of hard-coding a single quote.
            head = repr(value[:97])
            v = head[:-1] + "..." + head[-1]
        else:
            v = repr(value)

        return f"{self.__class__.__name__}(field_name={self.field_name!r}, value={v})"
344
345
class File:
    """This class represents an uploaded file. It handles writing file data to
    either an in-memory file or a temporary file on-disk, if the optional
    threshold is passed.

    There are some options that can be passed to the File to change behavior
    of the class. Valid options are as follows:

    | Name | Type | Default | Description |
    |-----------------------|-------|---------|-------------|
    | UPLOAD_DIR | `str` or `bytes` | None | The directory to store uploaded files in. If this is None, a temporary file will be created in the system's standard location. |
    | UPLOAD_DELETE_TMP | `bool`| True | Delete automatically created TMP file |
    | UPLOAD_KEEP_FILENAME | `bool`| False | Whether or not to keep the filename of the uploaded file (only used together with UPLOAD_DIR). If True, the basename of the uploaded file name is used, so any directory components are removed, and the file is saved under that name in UPLOAD_DIR. If no usable name was uploaded, a temporary name is used instead. |
    | UPLOAD_KEEP_EXTENSIONS| `bool`| False | Whether or not to keep the uploaded file's extension. If False, the file is saved without an extension. Otherwise, the file's extension will be maintained. Note that this will properly combine with the UPLOAD_KEEP_FILENAME setting. |
    | MAX_MEMORY_FILE_SIZE | `int` | 1 MiB | The maximum number of bytes of a File to keep in memory. By default, the contents of a File are kept into memory until a certain limit is reached, after which the contents of the File are written to a temporary file. This behavior can be disabled by setting this value to an appropriately large value (or, for example, infinity, such as `float('inf')`). If the key is absent from the config passed to File itself, the data is never moved to disk. |

    Args:
        file_name: The name of the file that this [`File`][python_multipart.File] represents.
        field_name: The name of the form field that this file was uploaded with. This can be None, if, for example,
            the file was uploaded with Content-Type application/octet-stream.
        config: The configuration for this File. See above for valid configuration keys and their corresponding values.
            None (the default) means an empty configuration.
        content_type: The value of the Content-Type header.
    """  # noqa: E501

    def __init__(
        self,
        file_name: bytes | None,
        field_name: bytes | None = None,
        config: FileConfig | None = None,
        *,
        content_type: str | None = None,
    ) -> None:
        # Save configuration, set other variables default.
        self.logger = logging.getLogger(__name__)
        # A fresh dict per instance instead of one shared mutable default.
        self._config: FileConfig = config if config is not None else {}
        self._in_memory = True
        self._bytes_written = 0
        self._fileobj: BytesIO | BufferedRandom = BytesIO()

        # Save the provided field/file name and content type.
        self._field_name = field_name
        self._file_name = file_name
        self._content_type = content_type

        # Our actual file name is None by default, since, depending on our
        # config, we may not actually use the provided name.
        self._actual_file_name: bytes | None = None

        # Split the extension from the filename. Both parts default to b"" so
        # that _get_disk_file() also works for uploads without a file name.
        self._file_base = b""
        self._ext = b""
        if file_name is not None:
            # Extract just the basename to avoid directory traversal
            basename = os.path.basename(file_name)
            self._file_base, self._ext = os.path.splitext(basename)

    @property
    def field_name(self) -> bytes | None:
        """The form field associated with this file. May be None if there isn't
        one, for example when we have an application/octet-stream upload.
        """
        return self._field_name

    @property
    def file_name(self) -> bytes | None:
        """The file name given in the upload request."""
        return self._file_name

    @property
    def actual_file_name(self) -> bytes | None:
        """The file name that this file is saved as. Will be None if it's not
        currently saved on disk.
        """
        return self._actual_file_name

    @property
    def file_object(self) -> BytesIO | BufferedRandom:
        """The file object that we're currently writing to. Note that this
        will either be an instance of a :class:`io.BytesIO`, or a regular file
        object.
        """
        return self._fileobj

    @property
    def size(self) -> int:
        """The total size of this file, counted as the number of bytes that
        currently have been written to the file.
        """
        return self._bytes_written

    @property
    def in_memory(self) -> bool:
        """A boolean representing whether or not this file object is currently
        stored in-memory or on-disk.
        """
        return self._in_memory

    @property
    def content_type(self) -> str | None:
        """The Content-Type value for this part, if it was set."""
        return self._content_type

    def flush_to_disk(self) -> None:
        """If the file is already on-disk, do nothing. Otherwise, copy from
        the in-memory buffer to a disk file, and then reassign our internal
        file object to this new disk file.

        Note that if you attempt to flush a file that is already on-disk, a
        warning will be logged to this module's logger.
        """
        if not self._in_memory:
            self.logger.warning("Trying to flush to disk when we're not in memory")
            return

        # Go back to the start of our file.
        self._fileobj.seek(0)

        # Open a new file.
        new_file = self._get_disk_file()

        # Copy the file objects.
        shutil.copyfileobj(self._fileobj, new_file)

        # Seek to the new position in our new file.
        new_file.seek(self._bytes_written)

        # Reassign the fileobject.
        old_fileobj = self._fileobj
        self._fileobj = new_file

        # We're no longer in memory.
        self._in_memory = False

        # Close the old file object.
        old_fileobj.close()

    def _get_disk_file(self) -> BufferedRandom:
        """Open the on-disk file that in-memory data is moved into.

        With ``UPLOAD_DIR`` set and ``UPLOAD_KEEP_FILENAME`` enabled, the
        upload's basename (plus its extension if ``UPLOAD_KEEP_EXTENSIONS``)
        is opened inside that directory. Otherwise, or when the upload has no
        usable name, a ``tempfile.NamedTemporaryFile`` is created (inside
        ``UPLOAD_DIR`` when it is set).

        Returns:
            The opened binary file; ``actual_file_name`` is set to its path.

        Raises:
            FileError: If the file cannot be opened or created.
        """
        self.logger.info("Opening a file on disk")

        file_dir = self._config.get("UPLOAD_DIR")
        keep_filename = self._config.get("UPLOAD_KEEP_FILENAME", False)
        keep_extensions = self._config.get("UPLOAD_KEEP_EXTENSIONS", False)
        delete_tmp = self._config.get("UPLOAD_DELETE_TMP", True)
        tmp_file: None | BufferedRandom = None

        fname = self._file_base + self._ext if keep_extensions else self._file_base

        # Only honor the client's name when it is usable: an empty name, "."
        # or ".." would make the joined path point at a directory.
        if file_dir is not None and keep_filename and fname not in (b"", b".", b".."):
            self.logger.info("Saving with filename in: %r", file_dir)

            # os.path.join() cannot mix str and bytes. The file name is bytes,
            # so a str UPLOAD_DIR is encoded with the filesystem encoding too.
            path = os.path.join(os.fsencode(file_dir), fname)
            try:
                self.logger.info("Opening file: %r", path)
                tmp_file = open(path, "w+b")
            except OSError:
                self.logger.exception("Error opening temporary file")
                raise FileError("Error opening temporary file: %r" % path)
        else:
            # Build options array.
            # Note that on Python 3, tempfile doesn't support byte names. We
            # encode our paths using the default filesystem encoding.
            suffix = self._ext.decode(sys.getfilesystemencoding()) if keep_extensions else None

            if file_dir is None:
                dir = None
            elif isinstance(file_dir, bytes):
                dir = file_dir.decode(sys.getfilesystemencoding())
            else:
                dir = file_dir  # pragma: no cover

            # Create a temporary (named) file with the appropriate settings.
            self.logger.info(
                "Creating a temporary file with options: %r", {"suffix": suffix, "delete": delete_tmp, "dir": dir}
            )
            try:
                tmp_file = cast(BufferedRandom, tempfile.NamedTemporaryFile(suffix=suffix, delete=delete_tmp, dir=dir))
            except OSError:
                self.logger.exception("Error creating named temporary file")
                raise FileError("Error creating named temporary file")

        assert tmp_file is not None
        # Encode filename as bytes.
        if isinstance(tmp_file.name, str):
            actual_name = tmp_file.name.encode(sys.getfilesystemencoding())
        else:
            actual_name = cast(bytes, tmp_file.name)  # pragma: no cover

        self._actual_file_name = actual_name
        return tmp_file

    def write(self, data: bytes) -> int:
        """Write some data to the File.

        :param data: a bytestring
        """
        return self.on_data(data)

    def on_data(self, data: bytes) -> int:
        """This method is a callback that will be called whenever data is
        written to the File.

        Args:
            data: The data to write to the file.

        Returns:
            The number of bytes written.
        """
        bwritten = self._fileobj.write(data)

        # If the bytes written isn't the same as the length, just return.
        if bwritten != len(data):
            self.logger.warning("bwritten != len(data) (%d != %d)", bwritten, len(data))
            return bwritten

        # Keep track of how many bytes we've written.
        self._bytes_written += bwritten

        # If we're in-memory and are over our limit, we create a file.
        max_memory_file_size = self._config.get("MAX_MEMORY_FILE_SIZE")
        if self._in_memory and max_memory_file_size is not None and (self._bytes_written > max_memory_file_size):
            self.logger.info("Flushing to disk")
            self.flush_to_disk()

        # Return the number of bytes written.
        return bwritten

    def on_end(self) -> None:
        """This method is called whenever the File is finalized."""
        # Flush the underlying file object
        self._fileobj.flush()

    def finalize(self) -> None:
        """Finalize the form file. This will not close the underlying file,
        but simply signal that we are finished writing to the File.
        """
        self.on_end()

    def close(self) -> None:
        """Close the File object. This will actually close the underlying
        file object (whether it's a :class:`io.BytesIO` or an actual file
        object).
        """
        self._fileobj.close()

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(file_name={self.file_name!r}, field_name={self.field_name!r})"
598
599
class BaseParser:
    """Common base class for the streaming parsers.

    It owns the ``callbacks`` dictionary and the helpers that invoke and
    (un)register its entries. Each callback is stored under ``"on_" + name``.

    There are two kinds of callbacks. *Notification* callbacks signal that
    something happened (for example, a new multipart part was found) and are
    called with no arguments. *Data* callbacks receive a chunk of input and
    are called as::

        data_callback(data, start, end)

    ``data`` is a ``bytes`` buffer and ``data[start:end]`` is the slice the
    callback is interested in. The buffer is handed over unsliced, because
    copying it would severely hurt performance.
    """

    def __init__(self) -> None:
        self.logger = logging.getLogger(__name__)
        self.callbacks: QuerystringCallbacks | OctetStreamCallbacks | MultipartCallbacks = {}

    def callback(
        self, name: CallbackName, data: bytes | None = None, start: int | None = None, end: int | None = None
    ) -> None:
        """Invoke the callback registered for ``name``; do nothing if none is.

        Args:
            name: The callback name, without the ``on_`` prefix.
            data: The buffer for a data callback. When None, the callback is
                treated as a notification callback and called with no arguments.
            start: Start index into ``data``, forwarded to a data callback.
            end: End index into ``data``, forwarded to a data callback.

        A data callback is not invoked at all when ``start`` is given and
        equals ``end`` (an empty slice).
        """
        key = "on_" + name
        handler = self.callbacks.get(key)
        if handler is None:
            return
        handler = cast("Callable[..., Any]", handler)

        if data is None:
            self.logger.debug("Calling %s with no data", key)
            handler()
            return

        # Nothing to report for an empty slice.
        if start is not None and start == end:
            return

        self.logger.debug("Calling %s with data[%d:%d]", key, start, end)
        handler(data, start, end)

    def set_callback(self, name: CallbackName, new_func: Callable[..., Any] | None) -> None:
        """Register ``new_func`` as the ``name`` callback, or unregister it.

        :param name: The callback name, without the ``on_`` prefix.

        :param new_func: The function to install. Passing None removes the
                         callback; removing one that is not registered is not
                         an error.
        """
        if new_func is not None:
            self.callbacks["on_" + name] = new_func  # type: ignore[literal-required]
        else:
            self.callbacks.pop("on_" + name, None)  # type: ignore[misc]

    def close(self) -> None:
        pass  # pragma: no cover

    def finalize(self) -> None:
        pass  # pragma: no cover

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}()"
678
679
class OctetStreamParser(BaseParser):
    """This parser parses an octet-stream request body and calls callbacks when
    incoming data is received. Callbacks are as follows:

    | Callback Name  | Parameters      | Description                                         |
    |----------------|-----------------|-----------------------------------------------------|
    | on_start       | None            | Called when the first data is parsed.               |
    | on_data        | data, start, end| Called for each data chunk that is parsed.          |
    | on_end         | None            | Called when the parser is finished parsing all data.|

    Args:
        callbacks: A dictionary of callbacks. See the documentation for [`BaseParser`][python_multipart.BaseParser].
            None (the default) starts with no callbacks.
        max_size: The maximum size of body to parse. Defaults to infinity - i.e. unbounded.
    """

    def __init__(self, callbacks: OctetStreamCallbacks | None = None, max_size: float = float("inf")):
        super().__init__()
        # Use a fresh dict when none is given: set_callback() mutates
        # self.callbacks in place, so a shared `{}` default would leak
        # callbacks between parser instances.
        self.callbacks = callbacks if callbacks is not None else {}
        self._started = False

        if not isinstance(max_size, Number) or max_size < 1:
            raise ValueError("max_size must be a positive number, not %r" % max_size)
        self.max_size: int | float = max_size
        # Total number of bytes accepted so far, across all write() calls.
        self._current_size = 0

    def write(self, data: bytes) -> int:
        """Write some data to the parser, which will perform size verification,
        and then pass the data to the underlying callback.

        The ``on_start`` callback fires before the first chunk. Bytes that
        would exceed ``max_size`` are dropped, and a warning is logged.

        Args:
            data: The data to write to the parser.

        Returns:
            The number of bytes written.
        """
        if not self._started:
            self.callback("start")
            self._started = True

        # Truncate data length.
        data_len = len(data)
        if (self._current_size + data_len) > self.max_size:
            # We truncate the length of data that we are to process.
            new_size = int(self.max_size - self._current_size)
            self.logger.warning(
                "Current size is %d (max %d), so truncating data length from %d to %d",
                self._current_size,
                self.max_size,
                data_len,
                new_size,
            )
            data_len = new_size

        # Increment size, then callback, in case there's an exception.
        self._current_size += data_len
        self.callback("data", data, 0, data_len)
        return data_len

    def finalize(self) -> None:
        """Finalize this parser, which signals that we are finished parsing,
        and sends the on_end callback.
        """
        self.callback("end")

    def __repr__(self) -> str:
        return "%s()" % self.__class__.__name__
746
747
748class QuerystringParser(BaseParser):
749 """This is a streaming querystring parser. It will consume data, and call
750 the callbacks given when it has data.
751
752 | Callback Name | Parameters | Description |
753 |----------------|-----------------|-----------------------------------------------------|
754 | on_field_start | None | Called when a new field is encountered. |
755 | on_field_name | data, start, end| Called when a portion of a field's name is encountered. |
756 | on_field_data | data, start, end| Called when a portion of a field's data is encountered. |
757 | on_field_end | None | Called when the end of a field is encountered. |
758 | on_end | None | Called when the parser is finished parsing all data.|
759
760 Args:
761 callbacks: A dictionary of callbacks. See the documentation for [`BaseParser`][python_multipart.BaseParser].
762 strict_parsing: Whether or not to parse the body strictly. Defaults to False. If this is set to True, then the
763 behavior of the parser changes as the following: if a field has a value with an equal sign
764 (e.g. "foo=bar", or "foo="), it is always included. If a field has no equals sign (e.g. "...&name&..."),
765 it will be treated as an error if 'strict_parsing' is True, otherwise included. If an error is encountered,
766 then a [`QuerystringParseError`][python_multipart.exceptions.QuerystringParseError] will be raised.
767 max_size: The maximum size of body to parse. Defaults to infinity - i.e. unbounded.
768 """ # noqa: E501
769
770 state: QuerystringState
771
772 def __init__(
773 self, callbacks: QuerystringCallbacks = {}, strict_parsing: bool = False, max_size: float = float("inf")
774 ) -> None:
775 super().__init__()
776 self.state = QuerystringState.BEFORE_FIELD
777 self._found_sep = False
778
779 self.callbacks = callbacks
780
781 # Max-size stuff
782 if not isinstance(max_size, Number) or max_size < 1:
783 raise ValueError("max_size must be a positive number, not %r" % max_size)
784 self.max_size: int | float = max_size
785 self._current_size = 0
786
787 # Should parsing be strict?
788 self.strict_parsing = strict_parsing
789
790 def write(self, data: bytes) -> int:
791 """Write some data to the parser, which will perform size verification,
792 parse into either a field name or value, and then pass the
793 corresponding data to the underlying callback. If an error is
794 encountered while parsing, a QuerystringParseError will be raised. The
795 "offset" attribute of the raised exception will be set to the offset in
796 the input data chunk (NOT the overall stream) that caused the error.
797
798 Args:
799 data: The data to write to the parser.
800
801 Returns:
802 The number of bytes written.
803 """
804 # Handle sizing.
805 data_len = len(data)
806 if (self._current_size + data_len) > self.max_size:
807 # We truncate the length of data that we are to process.
808 new_size = int(self.max_size - self._current_size)
809 self.logger.warning(
810 "Current size is %d (max %d), so truncating data length from %d to %d",
811 self._current_size,
812 self.max_size,
813 data_len,
814 new_size,
815 )
816 data_len = new_size
817
818 l = 0
819 try:
820 l = self._internal_write(data, data_len)
821 finally:
822 self._current_size += l
823
824 return l
825
826 def _internal_write(self, data: bytes, length: int) -> int:
827 state = self.state
828 strict_parsing = self.strict_parsing
829 found_sep = self._found_sep
830
831 i = 0
832 while i < length:
833 ch = data[i]
834
835 # Depending on our state...
836 if state == QuerystringState.BEFORE_FIELD:
837 # If the 'found_sep' flag is set, we've already encountered
838 # and skipped a single separator. If so, we check our strict
839 # parsing flag and decide what to do. Otherwise, we haven't
840 # yet reached a separator, and thus, if we do, we need to skip
841 # it as it will be the boundary between fields that's supposed
842 # to be there.
843 if ch == AMPERSAND or ch == SEMICOLON:
844 if found_sep:
845 # If we're parsing strictly, we disallow blank chunks.
846 if strict_parsing:
847 raise QuerystringParseError("Skipping duplicate ampersand/semicolon at %d" % i, offset=i)
848 else:
849 self.logger.debug("Skipping duplicate ampersand/semicolon at %d", i)
850 else:
851 # This case is when we're skipping the (first)
852 # separator between fields, so we just set our flag
853 # and continue on.
854 found_sep = True
855 else:
856 # Emit a field-start event, and go to that state. Also,
857 # reset the "found_sep" flag, for the next time we get to
858 # this state.
859 self.callback("field_start")
860 i -= 1
861 state = QuerystringState.FIELD_NAME
862 found_sep = False
863
864 elif state == QuerystringState.FIELD_NAME:
865 # Try and find a separator - we ensure that, if we do, we only
866 # look for the equal sign before it.
867 sep_pos = data.find(b"&", i)
868 if sep_pos == -1:
869 sep_pos = data.find(b";", i)
870
871 # See if we can find an equals sign in the remaining data. If
872 # so, we can immediately emit the field name and jump to the
873 # data state.
874 if sep_pos != -1:
875 equals_pos = data.find(b"=", i, sep_pos)
876 else:
877 equals_pos = data.find(b"=", i)
878
879 if equals_pos != -1:
880 # Emit this name.
881 self.callback("field_name", data, i, equals_pos)
882
883 # Jump i to this position. Note that it will then have 1
884 # added to it below, which means the next iteration of this
885 # loop will inspect the character after the equals sign.
886 i = equals_pos
887 state = QuerystringState.FIELD_DATA
888 else:
889 # No equals sign found.
890 if not strict_parsing:
891 # See also comments in the QuerystringState.FIELD_DATA case below.
892 # If we found the separator, we emit the name and just
893 # end - there's no data callback at all (not even with
894 # a blank value).
895 if sep_pos != -1:
896 self.callback("field_name", data, i, sep_pos)
897 self.callback("field_end")
898
899 i = sep_pos - 1
900 state = QuerystringState.BEFORE_FIELD
901 else:
902 # Otherwise, no separator in this block, so the
903 # rest of this chunk must be a name.
904 self.callback("field_name", data, i, length)
905 i = length
906
907 else:
908 # We're parsing strictly. If we find a separator,
909 # this is an error - we require an equals sign.
910 if sep_pos != -1:
911 raise QuerystringParseError(
912 "When strict_parsing is True, we require an "
913 "equals sign in all field chunks. Did not "
914 "find one in the chunk that starts at %d" % (i,),
915 offset=i,
916 )
917
918 # No separator in the rest of this chunk, so it's just
919 # a field name.
920 self.callback("field_name", data, i, length)
921 i = length
922
923 elif state == QuerystringState.FIELD_DATA:
924 # Try finding either an ampersand or a semicolon after this
925 # position.
926 sep_pos = data.find(b"&", i)
927 if sep_pos == -1:
928 sep_pos = data.find(b";", i)
929
930 # If we found it, callback this bit as data and then go back
931 # to expecting to find a field.
932 if sep_pos != -1:
933 self.callback("field_data", data, i, sep_pos)
934 self.callback("field_end")
935
936 # Note that we go to the separator, which brings us to the
937 # "before field" state. This allows us to properly emit
938 # "field_start" events only when we actually have data for
939 # a field of some sort.
940 i = sep_pos - 1
941 state = QuerystringState.BEFORE_FIELD
942
943 # Otherwise, emit the rest as data and finish.
944 else:
945 self.callback("field_data", data, i, length)
946 i = length
947
948 else: # pragma: no cover (error case)
949 msg = "Reached an unknown state %d at %d" % (state, i)
950 self.logger.warning(msg)
951 raise QuerystringParseError(msg, offset=i)
952
953 i += 1
954
955 self.state = state
956 self._found_sep = found_sep
957 return length
958
959 def finalize(self) -> None:
960 """Finalize this parser, which signals to that we are finished parsing,
961 if we're still in the middle of a field, an on_field_end callback, and
962 then the on_end callback.
963 """
964 # If we're currently in the middle of a field, we finish it.
965 if self.state in (QuerystringState.FIELD_DATA, QuerystringState.FIELD_NAME):
966 self.callback("field_end")
967 self.callback("end")
968
969 def __repr__(self) -> str:
970 return "{}(strict_parsing={!r}, max_size={!r})".format(
971 self.__class__.__name__, self.strict_parsing, self.max_size
972 )
973
974
class MultipartParser(BaseParser):
    """This class is a streaming multipart/form-data parser.

    | Callback Name | Parameters | Description |
    |--------------------|-----------------|-------------|
    | on_part_begin | None | Called when a new part of the multipart message is encountered. |
    | on_part_data | data, start, end| Called when a portion of a part's data is encountered. |
    | on_part_end | None | Called when the end of a part is reached. |
    | on_header_begin | None | Called when we've found a new header in a part of a multipart message |
    | on_header_field | data, start, end| Called each time an additional portion of a header is read (i.e. the part of the header that is before the colon; the "Foo" in "Foo: Bar"). |
    | on_header_value | data, start, end| Called when we get data for a header. |
    | on_header_end | None | Called when the current header is finished - i.e. we've reached the newline at the end of the header. |
    | on_headers_finished| None | Called when all headers are finished, and before the part data starts. |
    | on_end | None | Called when the parser is finished parsing all data. |

    Args:
        boundary: The multipart boundary. This is required, and must match what is given in the HTTP request - usually in the Content-Type header.
        callbacks: A dictionary of callbacks. See the documentation for [`BaseParser`][python_multipart.BaseParser]. Defaults to a new, empty dictionary for each instance.
        max_size: The maximum size of body to parse. Defaults to infinity - i.e. unbounded.
        max_header_count: The maximum number of headers allowed per part.
        max_header_size: The maximum size of a single header line (excluding the trailing CRLF).

    Raises:
        ValueError: If `max_size` is not a positive number.
        FormParserError: If the boundary is longer than `MAX_BOUNDARY_LENGTH`.
    """  # noqa: E501

    def __init__(
        self,
        boundary: bytes | str,
        callbacks: MultipartCallbacks | None = None,
        max_size: float = float("inf"),
        *,
        max_header_count: int = DEFAULT_MAX_HEADER_COUNT,
        max_header_size: int = DEFAULT_MAX_HEADER_SIZE,
    ) -> None:
        # Initialize parser state.
        super().__init__()
        self.state = MultipartState.START
        self.index = self.flags = 0

        # Use a None sentinel instead of a shared `{}` default so every parser
        # owns its own dict; otherwise mutating `self.callbacks` on one
        # instance would leak callbacks into every parser created afterwards.
        self.callbacks = callbacks if callbacks is not None else {}

        if not isinstance(max_size, Number) or max_size < 1:
            raise ValueError("max_size must be a positive number, not %r" % max_size)
        self.max_size = max_size
        self._current_size = 0

        self.max_header_count = max_header_count
        self._current_header_count = 0

        self.max_header_size = max_header_size
        self._current_header_size = 0

        # Setup marks. These are used to track the state of data received.
        # A mark is an index into the current chunk; a negative mark means the
        # data starts in a partial boundary match carried over from the
        # previous chunk (see `data_callback`).
        self.marks: dict[str, int] = {}

        # Save our boundary.
        if isinstance(boundary, str):  # pragma: no cover
            boundary = boundary.encode("latin-1")
        if len(boundary) > MAX_BOUNDARY_LENGTH:
            raise FormParserError(f"Boundary length {len(boundary)} exceeds maximum of {MAX_BOUNDARY_LENGTH}")
        # Inside the body every boundary after the first is preceded by CRLF,
        # so we search for "\r\n--<boundary>".
        self.boundary = b"\r\n--" + boundary

    def write(self, data: bytes) -> int:
        """Write some data to the parser, which will perform size verification,
        and then parse the data into the appropriate location (e.g. header,
        data, etc.), and pass this on to the underlying callback. If an error
        is encountered, a MultipartParseError will be raised. The "offset"
        attribute on the raised exception will be set to the offset of the byte
        in the input chunk that caused the error.

        Args:
            data: The data to write to the parser.

        Returns:
            The number of bytes processed. This is less than `len(data)` when
            the chunk had to be truncated to stay within `max_size`.
        """
        # Handle sizing.
        data_len = len(data)
        if (self._current_size + data_len) > self.max_size:
            # We truncate the length of data that we are to process.
            new_size = int(self.max_size - self._current_size)
            self.logger.warning(
                "Current size is %d (max %d), so truncating data length from %d to %d",
                self._current_size,
                self.max_size,
                data_len,
                new_size,
            )
            data_len = new_size

        l = 0
        try:
            l = self._internal_write(data, data_len)
        finally:
            self._current_size += l

        return l

    def _internal_write(self, data: bytes, length: int) -> int:
        # Get values from locals.
        boundary = self.boundary

        # Get our state, flags and index. These are persisted between calls to
        # this function.
        state = self.state
        index = self.index
        flags = self.flags
        current_header_count = self._current_header_count
        current_header_size = self._current_header_size

        # Our index defaults to 0.
        i = 0

        def advance_header_size(amount: int = 1) -> None:
            nonlocal current_header_size
            current_header_size += amount
            if current_header_size > self.max_header_size:
                raise MultipartParseError("Maximum header size exceeded", offset=i)

        # Set a mark.
        def set_mark(name: str) -> None:
            self.marks[name] = i

        # Remove a mark.
        def delete_mark(name: str) -> None:
            self.marks.pop(name, None)

        # Helper function that makes calling a callback with data easier. The
        # 'remaining' parameter will callback from the marked value until the
        # end of the buffer, and reset the mark, instead of deleting it. This
        # is used at the end of the function to call our callbacks with any
        # remaining data in this chunk.
        def data_callback(name: CallbackName, end_i: int, remaining: bool = False) -> None:
            marked_index = self.marks.get(name)
            if marked_index is None:
                return

            # Otherwise, we call it from the mark to the current byte we're
            # processing.
            if end_i <= marked_index:
                # There is no additional data to send.
                pass
            elif marked_index >= 0:
                # We are emitting data from the local buffer.
                self.callback(name, data, marked_index, end_i)
            else:
                # Some of the data comes from a partial boundary match.
                # and requires look-behind.
                # We need to use self.flags (and not flags) because we care about
                # the state when we entered the loop.
                lookbehind_len = -marked_index
                if lookbehind_len <= len(boundary):
                    self.callback(name, boundary, 0, lookbehind_len)
                elif self.flags & FLAG_PART_BOUNDARY:
                    lookback = boundary + b"\r\n"
                    self.callback(name, lookback, 0, lookbehind_len)
                elif self.flags & FLAG_LAST_BOUNDARY:
                    lookback = boundary + b"--\r\n"
                    self.callback(name, lookback, 0, lookbehind_len)
                else:  # pragma: no cover (error case)
                    self.logger.warning("Look-back buffer error")

                if end_i > 0:
                    self.callback(name, data, 0, end_i)
            # If we're getting remaining data, we have got all the data we
            # can be certain is not a boundary, leaving only a partial boundary match.
            if remaining:
                self.marks[name] = end_i - length
            else:
                self.marks.pop(name, None)

        # For each byte...
        while i < length:
            c = data[i]

            if state == MultipartState.START:
                # Skip leading newlines by jumping straight to the next hyphen
                # (the first byte of "--boundary"). The search is bounded by
                # `length` so bytes dropped by max_size truncation are never
                # scanned.
                if c == CR or c == LF:
                    i = data.find(b"-", i, length)
                    if i == -1:
                        # No boundary candidate in this chunk, so ignore the content after the leading CR/LF.
                        i = length
                        break
                    continue

                # index is used as in index into our boundary. Set to 0.
                index = 0

                # Move to the next state, but decrement i so that we re-process
                # this character.
                state = MultipartState.START_BOUNDARY
                i -= 1

            elif state == MultipartState.START_BOUNDARY:
                # Check to ensure that the last 2 characters in our boundary
                # are CRLF.
                if index == len(boundary) - 2:
                    if c == HYPHEN:
                        # Potential empty message.
                        state = MultipartState.END_BOUNDARY
                    elif c != CR:
                        # Error!
                        msg = "Did not find CR at end of boundary (%d)" % (i,)
                        self.logger.warning(msg)
                        raise MultipartParseError(msg, offset=i)

                    index += 1

                elif index == len(boundary) - 2 + 1:
                    if c != LF:
                        msg = "Did not find LF at end of boundary (%d)" % (i,)
                        self.logger.warning(msg)
                        raise MultipartParseError(msg, offset=i)

                    # The index is now used for indexing into our boundary.
                    index = 0

                    # Callback for the start of a part.
                    self.callback("part_begin")
                    current_header_count = 0
                    current_header_size = 0

                    # Move to the next character and state.
                    state = MultipartState.HEADER_FIELD_START

                else:
                    # Check to ensure our boundary matches. The first boundary
                    # has no leading CRLF, so skip the first 2 bytes of it.
                    if c != boundary[index + 2]:
                        msg = "Expected boundary character %r, got %r at index %d" % (boundary[index + 2], c, index + 2)
                        self.logger.warning(msg)
                        raise MultipartParseError(msg, offset=i)

                    # Increment index into boundary and continue.
                    index += 1

            elif state == MultipartState.HEADER_FIELD_START:
                # Mark the start of a header field here, reset the index, and
                # continue parsing our header field.
                index = 0

                if c != CR:
                    current_header_count += 1
                    if current_header_count > self.max_header_count:
                        raise MultipartParseError("Maximum header count exceeded", offset=i)
                    current_header_size = 0

                # Set a mark of our header field.
                set_mark("header_field")

                # Notify that we're starting a header if the next character is
                # not a CR; a CR at the beginning of the header will cause us
                # to stop parsing headers in the MultipartState.HEADER_FIELD state,
                # below.
                if c != CR:
                    self.callback("header_begin")

                # Move to parsing header fields.
                state = MultipartState.HEADER_FIELD
                i -= 1

            elif state == MultipartState.HEADER_FIELD:
                # If we've reached a CR at the beginning of a header, it means
                # that we've reached the second of 2 newlines, and so there are
                # no more headers to parse.
                if c == CR and index == 0:
                    delete_mark("header_field")
                    state = MultipartState.HEADERS_ALMOST_DONE
                    i += 1
                    continue

                # Increment our index in the header.
                index += 1

                # If we've reached a colon, we're done with this header.
                if c == COLON:
                    advance_header_size()
                    # A 0-length header is an error.
                    if index == 1:
                        msg = "Found 0-length header at %d" % (i,)
                        self.logger.warning(msg)
                        raise MultipartParseError(msg, offset=i)

                    # Call our callback with the header field.
                    data_callback("header_field", i)

                    # Move to parsing the header value.
                    state = MultipartState.HEADER_VALUE_START

                elif c not in TOKEN_CHARS_SET:
                    msg = "Found invalid character %r in header at %d" % (c, i)
                    self.logger.warning(msg)
                    raise MultipartParseError(msg, offset=i)
                else:
                    advance_header_size()

            elif state == MultipartState.HEADER_VALUE_START:
                # Skip leading spaces.
                if c == SPACE:
                    advance_header_size()
                    i += 1
                    continue

                # Mark the start of the header value.
                set_mark("header_value")

                # Move to the header-value state, reprocessing this character.
                state = MultipartState.HEADER_VALUE
                i -= 1

            elif state == MultipartState.HEADER_VALUE:
                # If we've got a CR, we're nearly done our headers. Otherwise,
                # we do nothing and just move past this character.
                if c == CR:
                    data_callback("header_value", i)
                    self.callback("header_end")
                    current_header_size = 0
                    state = MultipartState.HEADER_VALUE_ALMOST_DONE
                else:
                    advance_header_size()

            elif state == MultipartState.HEADER_VALUE_ALMOST_DONE:
                # The last character should be a LF. If not, it's an error.
                if c != LF:
                    msg = f"Did not find LF character at end of header (found {c!r})"
                    self.logger.warning(msg)
                    raise MultipartParseError(msg, offset=i)

                # Move back to the start of another header. Note that if that
                # state detects ANOTHER newline, it'll trigger the end of our
                # headers.
                state = MultipartState.HEADER_FIELD_START

            elif state == MultipartState.HEADERS_ALMOST_DONE:
                # We're almost done our headers. This is reached when we parse
                # a CR at the beginning of a header, so our next character
                # should be a LF, or it's an error.
                if c != LF:
                    msg = f"Did not find LF at end of headers (found {c!r})"
                    self.logger.warning(msg)
                    raise MultipartParseError(msg, offset=i)

                self.callback("headers_finished")
                state = MultipartState.PART_DATA_START

            elif state == MultipartState.PART_DATA_START:
                # Mark the start of our part data.
                set_mark("part_data")

                # Start processing part data, including this character.
                state = MultipartState.PART_DATA
                i -= 1

            elif state == MultipartState.PART_DATA:
                # We're processing our part data right now. During this, we
                # need to efficiently search for our boundary, since any data
                # on any number of lines can be a part of the current data.

                # Save the current value of our index. We use this in case we
                # find part of a boundary, but it doesn't match fully.
                prev_index = index

                # Set up variables.
                boundary_length = len(boundary)
                data_length = length

                # If our index is 0, we're starting a new part, so start our
                # search.
                if index == 0:
                    # The most common case is likely to be that the whole
                    # boundary is present in the buffer.
                    # Calling `find` is much faster than iterating here.
                    i0 = data.find(boundary, i, data_length)
                    if i0 >= 0:
                        # We matched the whole boundary string.
                        index = boundary_length - 1
                        i = i0 + boundary_length - 1
                    else:
                        # No match found for whole string.
                        # There may be a partial boundary at the end of the
                        # data, which the find will not match.
                        # Since the length to be searched is limited to the
                        # boundary length, scan the tail for boundary[0] via
                        # bytes.find (C-level) to keep cost off the Python loop.
                        i = max(i, data_length - boundary_length)
                        j = data.find(boundary[:1], i, data_length - 1)
                        i = j if j >= 0 else data_length - 1

                    c = data[i]

                # Now, we have a couple of cases here. If our index is before
                # the end of the boundary...
                if index < boundary_length:
                    # If the character matches...
                    if boundary[index] == c:
                        # The current character matches, so continue!
                        index += 1
                    else:
                        index = 0

                # Our index is equal to the length of our boundary!
                elif index == boundary_length:
                    # First we increment it.
                    index += 1

                    # Now, if we've reached a newline, we need to set this as
                    # the potential end of our boundary.
                    if c == CR:
                        flags |= FLAG_PART_BOUNDARY

                    # Otherwise, if this is a hyphen, we might be at the last
                    # of all boundaries.
                    elif c == HYPHEN:
                        flags |= FLAG_LAST_BOUNDARY

                    # Otherwise, we reset our index, since this isn't either a
                    # newline or a hyphen.
                    else:
                        index = 0

                # Our index is right after the part boundary, which should be
                # a LF.
                elif index == boundary_length + 1:
                    # If we're at a part boundary (i.e. we've seen a CR
                    # character already)...
                    if flags & FLAG_PART_BOUNDARY:
                        # We need a LF character next.
                        if c == LF:
                            # Unset the part boundary flag.
                            flags &= ~FLAG_PART_BOUNDARY

                            # We have identified a boundary, callback for any data before it.
                            data_callback("part_data", i - index)
                            # Callback indicating that we've reached the end of
                            # a part, and are starting a new one.
                            self.callback("part_end")
                            self.callback("part_begin")
                            current_header_count = 0
                            current_header_size = 0

                            # Move to parsing new headers.
                            index = 0
                            state = MultipartState.HEADER_FIELD_START
                            i += 1
                            continue

                        # We didn't find an LF character, so no match. Reset
                        # our index and clear our flag.
                        index = 0
                        flags &= ~FLAG_PART_BOUNDARY

                    # Otherwise, if we're at the last boundary (i.e. we've
                    # seen a hyphen already)...
                    elif flags & FLAG_LAST_BOUNDARY:
                        # We need a second hyphen here.
                        if c == HYPHEN:
                            # We have identified a boundary, callback for any data before it.
                            data_callback("part_data", i - index)
                            # Callback to end the current part, and then the
                            # message.
                            self.callback("part_end")
                            self.callback("end")
                            state = MultipartState.END
                        else:
                            # No match, so reset index.
                            index = 0

                # Otherwise, our index is 0. If the previous index is not, a
                # partial boundary match just failed; the bytes we thought were
                # boundary remain covered by the "part_data" mark and will be
                # emitted as ordinary data.
                if index == 0 and prev_index > 0:
                    # Re-consider the current character, since this could be
                    # the start of the boundary itself.
                    i -= 1

            elif state == MultipartState.END_BOUNDARY:
                if index == len(boundary) - 2 + 1:
                    if c != HYPHEN:
                        msg = "Did not find - at end of boundary (%d)" % (i,)
                        self.logger.warning(msg)
                        raise MultipartParseError(msg, offset=i)
                    index += 1
                    self.callback("end")
                    state = MultipartState.END

            elif state == MultipartState.END:
                # Silently discard any epilogue data (RFC 2046 section 5.1.1 allows a CRLF and optional
                # epilogue after the closing boundary). Django and Werkzeug do the same.
                i = length
                break

            else:  # pragma: no cover (error case)
                # We got into a strange state somehow! Just stop processing.
                msg = "Reached an unknown state %d at %d" % (state, i)
                self.logger.warning(msg)
                raise MultipartParseError(msg, offset=i)

            # Move to the next byte.
            i += 1

        # We call our callbacks with any remaining data. Note that we pass
        # the 'remaining' flag, which re-bases the mark (to `end_i - length`)
        # instead of deleting it, if it's found. If a mark is still present at
        # this point, there is data for one of these things that has been
        # parsed but not yet emitted, so we haven't reached the end of this
        # 'thing'. Header marks end up at 0, so the next chunk starts emitting
        # from its beginning. The part-data mark ends up at -index: the bytes
        # held back as a possible partial boundary are replayed from the
        # boundary/look-behind buffer in the next call if the match fails.
        data_callback("header_field", length, True)
        data_callback("header_value", length, True)
        data_callback("part_data", length - index, True)

        # Save values to locals.
        self.state = state
        self.index = index
        self.flags = flags
        self._current_header_count = current_header_count
        self._current_header_size = current_header_size

        # Return our data length to indicate no errors, and that we processed
        # all of it.
        return length

    def finalize(self) -> None:
        """Finalize this parser, which signals that we are finished parsing.

        Note: It does not currently, but in the future, it will verify that we
        are in the final state of the parser (i.e. the end of the multipart
        message is well-formed), and, if not, throw an error.
        """
        # TODO: verify that we're in the state MultipartState.END, otherwise throw an
        # error or otherwise state that we're not finished parsing.
        pass

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(boundary={self.boundary!r})"
1512
1513
class FormParser:
    """This class is the all-in-one form parser. Given all the information
    necessary to parse a form, it will instantiate the correct parser, create
    the proper :class:`Field` and :class:`File` classes to store the data that
    is parsed, and call the two given callbacks with each field and file as
    they become available.

    The underlying parser is chosen from `content_type`:
    "application/octet-stream" uses OctetStreamParser, producing one File;
    "application/x-www-form-urlencoded" and "application/x-url-encoded" use
    QuerystringParser, producing Fields; "multipart/form-data" uses
    MultipartParser, producing Fields and Files. Any other value raises
    FormParserError.

    Args:
        content_type: The Content-Type of the incoming request. This is used to select the appropriate parser.
        on_field: The callback to call when a field has been parsed and is ready for usage. See above for parameters.
        on_file: The callback to call when a file has been parsed and is ready for usage. See above for parameters.
        on_end: An optional callback to call when all fields and files in a request have been parsed. Can be None.
        boundary: If the request is a multipart/form-data request, this should be the boundary of the request, as given
            in the Content-Type header, as a bytestring.
        file_name: If the request is of type application/octet-stream, then the body of the request will not contain any
            information about the uploaded file. In such cases, you can provide the file name of the uploaded file
            manually.
        config: Configuration to use for this FormParser. The default values are taken from the DEFAULT_CONFIG value,
            and then any keys present in this dictionary will overwrite the default values.

    Raises:
        FormParserError: If the Content-Type is unknown, or it is multipart/form-data and no boundary is given.
    """

    #: This is the default configuration for our form parser.
    #: Note: all file sizes should be in bytes.
    DEFAULT_CONFIG: FormParserConfig = {
        "MAX_BODY_SIZE": float("inf"),
        "MAX_HEADER_COUNT": DEFAULT_MAX_HEADER_COUNT,
        "MAX_HEADER_SIZE": DEFAULT_MAX_HEADER_SIZE,
        "MAX_MEMORY_FILE_SIZE": 1 * 1024 * 1024,
        "UPLOAD_DIR": None,
        "UPLOAD_DELETE_TMP": True,
        "UPLOAD_KEEP_FILENAME": False,
        "UPLOAD_KEEP_EXTENSIONS": False,
        # Error on invalid Content-Transfer-Encoding?
        "UPLOAD_ERROR_ON_BAD_CTE": False,
    }

    def __init__(
        self,
        content_type: str,
        on_field: Callable[[Field], None] | None,
        on_file: Callable[[File], None] | None,
        on_end: Callable[[], None] | None = None,
        boundary: bytes | str | None = None,
        file_name: bytes | None = None,
        config: dict[Any, Any] = {},
    ) -> None:
        self.logger = logging.getLogger(__name__)

        # Save variables.
        self.content_type = content_type
        self.boundary = boundary
        self.bytes_received = 0
        self.parser = None

        # Save callbacks.
        self.on_field = on_field
        self.on_file = on_file
        self.on_end = on_end

        # Set configuration options. Copy the class-level defaults so that
        # per-instance overrides from `config` never mutate DEFAULT_CONFIG
        # (the `config` argument itself is only read, never modified).
        self.config: FormParserConfig = self.DEFAULT_CONFIG.copy()
        self.config.update(config)  # type: ignore[typeddict-item]

        parser: OctetStreamParser | MultipartParser | QuerystringParser | None = None

        # Depending on the Content-Type, we instantiate the correct parser.
        if content_type == "application/octet-stream":
            # The whole body is a single file; `file` is shared between the
            # three callbacks below via `nonlocal`.
            file: File | None = None

            def on_start() -> None:
                nonlocal file
                file = File(file_name, None, config=self.config)

            def on_data(data: bytes, start: int, end: int) -> None:
                nonlocal file
                assert file is not None
                file.write(data[start:end])

            def _on_end() -> None:
                nonlocal file
                # NOTE(review): this assumes OctetStreamParser always emits
                # "start" before "end" (e.g. even for an empty body, or when
                # finalize() is called without any write()) — confirm, since
                # otherwise this assert fires.
                assert file is not None
                # Finalize the file itself.
                file.finalize()

                # Call our callback.
                if on_file:
                    on_file(file)

                # Call the on-end callback.
                if self.on_end is not None:
                    self.on_end()

            # Instantiate an octet-stream parser
            parser = OctetStreamParser(
                callbacks={"on_start": on_start, "on_data": on_data, "on_end": _on_end},
                max_size=self.config["MAX_BODY_SIZE"],
            )

        elif content_type == "application/x-www-form-urlencoded" or content_type == "application/x-url-encoded":
            # Field names may arrive in several "field_name" callbacks, so the
            # pieces are buffered here and joined when the Field is created.
            name_buffer: list[bytes] = []

            # The Field currently being built; None between fields.
            f: Field | None = None

            def on_field_start() -> None:
                pass

            def on_field_name(data: bytes, start: int, end: int) -> None:
                name_buffer.append(data[start:end])

            def on_field_data(data: bytes, start: int, end: int) -> None:
                nonlocal f
                # The Field is created lazily on the first data chunk, once
                # the full name has been buffered.
                if f is None:
                    f = Field(b"".join(name_buffer))
                    del name_buffer[:]
                f.write(data[start:end])

            def on_field_end() -> None:
                nonlocal f
                # Finalize and call callback.
                if f is None:
                    # If we get here, it's because there was no field data.
                    # We create a field, set it to None, and then continue.
                    f = Field(b"".join(name_buffer))
                    del name_buffer[:]
                    f.set_none()

                f.finalize()
                if on_field:
                    on_field(f)
                # Reset so the next field gets a fresh Field object.
                f = None

            def _on_end() -> None:
                if self.on_end is not None:
                    self.on_end()

            # Instantiate parser.
            parser = QuerystringParser(
                callbacks={
                    "on_field_start": on_field_start,
                    "on_field_name": on_field_name,
                    "on_field_data": on_field_data,
                    "on_field_end": on_field_end,
                    "on_end": _on_end,
                },
                max_size=self.config["MAX_BODY_SIZE"],
            )

        elif content_type == "multipart/form-data":
            if boundary is None:
                self.logger.error("No boundary given")
                raise FormParserError("No boundary given")

            # Header name/value fragments for the header currently being
            # parsed, joined in on_header_end.
            header_name: list[bytes] = []
            header_value: list[bytes] = []
            # Headers of the current part, keyed by lower-cased name.
            headers: dict[bytes, bytes] = {}

            # The Field/File for the current part, and the object part data is
            # written into (either f_multi itself or a decoder wrapping it).
            f_multi: File | Field | None = None
            writer: File | Field | Base64Decoder | QuotedPrintableDecoder | None = None
            is_file = False

            def on_part_begin() -> None:
                # Reset headers in case this isn't the first part.
                nonlocal headers
                headers = {}

            def on_part_data(data: bytes, start: int, end: int) -> None:
                nonlocal writer
                assert writer is not None
                writer.write(data[start:end])
                # TODO: check for error here.

            def on_part_end() -> None:
                nonlocal f_multi, is_file
                assert f_multi is not None
                # NOTE(review): only `f_multi` is finalized here, not `writer`.
                # When `writer` is a Base64Decoder/QuotedPrintableDecoder it is
                # finalized only in `_on_end`, and only for the last part. If
                # those decoders buffer trailing input, that input may never
                # reach earlier parts' Field/File — verify against .decoders.
                f_multi.finalize()
                if is_file:
                    if on_file:
                        assert isinstance(f_multi, File)
                        on_file(f_multi)
                else:
                    if on_field:
                        assert isinstance(f_multi, Field)
                        on_field(f_multi)

            def on_header_field(data: bytes, start: int, end: int) -> None:
                header_name.append(data[start:end])

            def on_header_value(data: bytes, start: int, end: int) -> None:
                header_value.append(data[start:end])

            def on_header_end() -> None:
                # Header names are stored lower-cased so lookups below are
                # case-insensitive.
                headers[b"".join(header_name).lower()] = b"".join(header_value)
                del header_name[:]
                del header_value[:]

            def on_headers_finished() -> None:
                nonlocal is_file, f_multi, writer
                # Reset the 'is file' flag.
                is_file = False

                # Parse the content-disposition header.
                content_disp = headers.get(b"content-disposition")
                disp, options = parse_options_header(content_disp)

                # Get the field and filename. Note: `file_name` and
                # `content_type` below are locals of this nested function; they
                # shadow, but do not modify, the __init__ parameters.
                field_name = options.get(b"name")
                file_name = options.get(b"filename")
                # RFC 7578 §4.2: each part MUST have a Content-Disposition header with a "name" parameter.
                if field_name is None:
                    raise FormParserError(f'Field name not found in Content-Disposition: "{content_disp!r}"')

                # Create the proper class. A part with a "filename" parameter
                # becomes a File; anything else becomes a Field.
                content_type_b = headers.get(b"content-type")
                content_type = content_type_b.decode("latin-1") if content_type_b is not None else None
                if file_name is None:
                    f_multi = Field(field_name, content_type=content_type)
                else:
                    f_multi = File(file_name, field_name, config=self.config, content_type=content_type)
                    is_file = True

                # Parse the given Content-Transfer-Encoding to determine what
                # we need to do with the incoming data.
                # TODO: check that we properly handle 8bit / 7bit encoding.
                # RFC 2045 section 6.1: Content-Transfer-Encoding values are case-insensitive.
                # https://www.rfc-editor.org/rfc/rfc2045#section-6.1
                transfer_encoding = headers.get(b"content-transfer-encoding", b"7bit").lower()

                if transfer_encoding in (b"binary", b"8bit", b"7bit"):
                    writer = f_multi

                elif transfer_encoding == b"base64":
                    writer = Base64Decoder(f_multi)

                elif transfer_encoding == b"quoted-printable":
                    writer = QuotedPrintableDecoder(f_multi)

                else:
                    self.logger.warning("Unknown Content-Transfer-Encoding: %r", transfer_encoding)
                    if self.config["UPLOAD_ERROR_ON_BAD_CTE"]:
                        raise FormParserError(f'Unknown Content-Transfer-Encoding "{transfer_encoding!r}"')
                    else:
                        # If we aren't erroring, then we just treat this as an
                        # unencoded Content-Transfer-Encoding.
                        writer = f_multi

            def _on_end() -> None:
                nonlocal writer
                # Finalizes the writer of the last part only (see the note in
                # on_part_end).
                if writer is not None:
                    writer.finalize()
                if self.on_end is not None:
                    self.on_end()

            # Instantiate a multipart parser.
            parser = MultipartParser(
                boundary,
                callbacks={
                    "on_part_begin": on_part_begin,
                    "on_part_data": on_part_data,
                    "on_part_end": on_part_end,
                    "on_header_field": on_header_field,
                    "on_header_value": on_header_value,
                    "on_header_end": on_header_end,
                    "on_headers_finished": on_headers_finished,
                    "on_end": _on_end,
                },
                max_size=self.config["MAX_BODY_SIZE"],
                max_header_count=self.config["MAX_HEADER_COUNT"],
                max_header_size=self.config["MAX_HEADER_SIZE"],
            )

        else:
            self.logger.warning("Unknown Content-Type: %r", content_type)
            raise FormParserError(f"Unknown Content-Type: {content_type}")

        self.parser = parser

    def write(self, data: bytes) -> int:
        """Write some data. The parser will forward this to the appropriate
        underlying parser.

        Args:
            data: The data to write.

        Returns:
            The number of bytes processed.
        """
        # bytes_received counts every byte passed in, regardless of how many
        # the underlying parser reports as processed.
        self.bytes_received += len(data)
        # TODO: check the parser's return value for errors?
        assert self.parser is not None
        return self.parser.write(data)

    def finalize(self) -> None:
        """Finalize the parser."""
        if self.parser is not None and hasattr(self.parser, "finalize"):
            self.parser.finalize()

    def close(self) -> None:
        """Close the parser."""
        # Not every underlying parser necessarily defines close(), hence the
        # hasattr guard.
        if self.parser is not None and hasattr(self.parser, "close"):
            self.parser.close()

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(content_type={self.content_type!r}, parser={self.parser!r})"
1817
1818
def create_form_parser(
    headers: dict[str, bytes],
    on_field: Callable[[Field], None] | None,
    on_file: Callable[[File], None] | None,
    config: dict[Any, Any] | None = None,
) -> FormParser:
    """Create a FormParser configured from a request's headers.

    Given a dictionary-like headers object, this determines the content type
    and the (optional) multipart boundary, then instantiates a FormParser with
    those values and the given callbacks.

    Args:
        headers: A dictionary-like object of HTTP headers. The only required header is Content-Type.
        on_field: Callback to call with each parsed field.
        on_file: Callback to call with each parsed file.
        config: Configuration variables to pass to the FormParser. ``None`` (the
            default) means no overrides; an empty dict is passed on in that case.

    Returns:
        The newly created FormParser.

    Raises:
        ValueError: If no Content-Type header is given.
    """
    content_type: str | bytes | None = headers.get("Content-Type")
    if content_type is None:
        logging.getLogger(__name__).warning("No Content-Type header given")
        raise ValueError("No Content-Type header given!")

    # Boundaries are optional (the FormParser will raise if one is needed
    # but not given).
    mime_type, params = parse_options_header(content_type)
    boundary = params.get(b"boundary")

    # FormParser expects the content type as a str, not a bytes object.
    # A fresh dict is created per call instead of sharing a mutable default.
    return FormParser(
        mime_type.decode("latin-1"),
        on_field,
        on_file,
        boundary=boundary,
        config={} if config is None else config,
    )
1855
1856
def parse_form(
    headers: dict[str, bytes],
    input_stream: SupportsRead,
    on_field: Callable[[Field], None] | None,
    on_file: Callable[[File], None] | None,
    chunk_size: int = 1048576,
) -> None:
    """Parse an entire request body, invoking callbacks for each field and file.

    This is a convenience wrapper for the common case: pass a dictionary-like
    object of the request's headers and a file-like object for the body, along
    with two callbacks that get called whenever a field or file is parsed.

    Args:
        headers: A dictionary-like object of HTTP headers. The only required header is Content-Type.
            If Content-Length is present, at most that many bytes are read from ``input_stream``.
        input_stream: A file-like object that represents the request body. The read() method must return bytestrings.
        on_field: Callback to call with each parsed field.
        on_file: Callback to call with each parsed file.
        chunk_size: The maximum size to read from the input stream and write to the parser at one time.
            Defaults to 1 MiB.

    Raises:
        ValueError: If ``chunk_size`` is not positive, if Content-Length is not an
            integer or is negative, or if no Content-Type header is given.
    """
    if chunk_size < 1:
        raise ValueError(f"chunk_size must be a positive number, not {chunk_size!r}")

    # Validate Content-Length before doing any work. A negative value must be
    # rejected explicitly: it would otherwise reach ``input_stream.read()``,
    # where a negative size means "read everything", turning the bounded read
    # below into an unbounded one.
    raw_length = headers.get("Content-Length")
    content_length: int | float
    if raw_length is None:
        content_length = float("inf")
    else:
        content_length = int(raw_length)
        if content_length < 0:
            raise ValueError(f"Content-Length must be a non-negative integer, not {raw_length!r}")

    # Create our form parser.
    parser = create_form_parser(headers, on_field, on_file)

    bytes_read = 0
    while True:
        # Read at most one chunk, and never past the given Content-Length.
        max_readable = int(min(content_length - bytes_read, chunk_size))
        buff = input_stream.read(max_readable)

        # Write to the parser and update our length.
        parser.write(buff)
        bytes_read += len(buff)

        # A buffer smaller than requested is treated as end of stream; we also
        # stop once Content-Length bytes have been consumed.
        if len(buff) != max_readable or bytes_read == content_length:
            break

    # Tell our parser that we're done writing data.
    parser.finalize()