1from __future__ import annotations
2
3import logging
4import os
5import shutil
6import sys
7import tempfile
8from email.message import Message
9from enum import IntEnum
10from io import BufferedRandom, BytesIO
11from numbers import Number
12from typing import TYPE_CHECKING, cast
13
14from .decoders import Base64Decoder, QuotedPrintableDecoder
15from .exceptions import FileError, FormParserError, MultipartParseError, QuerystringParseError
16
if TYPE_CHECKING:  # pragma: no cover
    from collections.abc import Callable
    from typing import Any, Literal, Protocol, TypeAlias, TypedDict

    # Minimal readable-stream interface; used for static typing only.
    class SupportsRead(Protocol):
        def read(self, __n: int) -> bytes: ...

    # Callbacks accepted by QuerystringParser (see its `callbacks` argument).
    class QuerystringCallbacks(TypedDict, total=False):
        on_field_start: Callable[[], None]
        on_field_name: Callable[[bytes, int, int], None]
        on_field_data: Callable[[bytes, int, int], None]
        on_field_end: Callable[[], None]
        on_end: Callable[[], None]

    # Callbacks accepted by OctetStreamParser (see its `callbacks` argument).
    class OctetStreamCallbacks(TypedDict, total=False):
        on_start: Callable[[], None]
        on_data: Callable[[bytes, int, int], None]
        on_end: Callable[[], None]

    # Callbacks accepted by MultipartParser (see its callback table).
    class MultipartCallbacks(TypedDict, total=False):
        on_part_begin: Callable[[], None]
        on_part_data: Callable[[bytes, int, int], None]
        on_part_end: Callable[[], None]
        on_header_begin: Callable[[], None]
        on_header_field: Callable[[bytes, int, int], None]
        on_header_value: Callable[[bytes, int, int], None]
        on_header_end: Callable[[], None]
        on_headers_finished: Callable[[], None]
        on_end: Callable[[], None]

    # Configuration mapping — presumably consumed by FormParser (its
    # definition is not visible in this chunk; confirm against it).
    class FormParserConfig(TypedDict):
        UPLOAD_DIR: str | None
        UPLOAD_KEEP_FILENAME: bool
        UPLOAD_KEEP_EXTENSIONS: bool
        UPLOAD_ERROR_ON_BAD_CTE: bool
        MAX_MEMORY_FILE_SIZE: int
        MAX_BODY_SIZE: float

    # Per-file configuration accepted by File (see File.__init__ / the
    # option table in File's docstring).
    class FileConfig(TypedDict, total=False):
        UPLOAD_DIR: str | bytes | None
        UPLOAD_DELETE_TMP: bool
        UPLOAD_KEEP_FILENAME: bool
        UPLOAD_KEEP_EXTENSIONS: bool
        MAX_MEMORY_FILE_SIZE: int

    # Structural interface shared by Field and File.
    class _FormProtocol(Protocol):
        def write(self, data: bytes) -> int: ...

        def finalize(self) -> None: ...

        def close(self) -> None: ...

    class FieldProtocol(_FormProtocol, Protocol):
        def __init__(self, name: bytes | None) -> None: ...

        def set_none(self) -> None: ...

    class FileProtocol(_FormProtocol, Protocol):
        def __init__(self, file_name: bytes | None, field_name: bytes | None, config: FileConfig) -> None: ...

    OnFieldCallback = Callable[[FieldProtocol], None]
    OnFileCallback = Callable[[FileProtocol], None]

    # Event names accepted by BaseParser.callback / BaseParser.set_callback;
    # the "on_" prefix is added internally.
    CallbackName: TypeAlias = Literal[
        "start",
        "data",
        "end",
        "field_start",
        "field_name",
        "field_data",
        "field_end",
        "part_begin",
        "part_data",
        "part_end",
        "header_begin",
        "header_field",
        "header_value",
        "header_end",
        "headers_finished",
    ]
97
# Unique sentinel object. Distinguishes "no cached value computed yet" from a
# legitimately cached value of None (see Field._cache and Field.set_none).
_missing = object()
100
101
class QuerystringState(IntEnum):
    """Querystring parser states.

    These are used to keep track of the state of the parser, and are used to determine
    what to do when new data is encountered.
    """

    BEFORE_FIELD = 0  # expecting a new field; '&'/';' separators are skipped here
    FIELD_NAME = 1  # consuming a field name, up to '=' or a separator
    FIELD_DATA = 2  # consuming a field value, up to a separator
112
113
class MultipartState(IntEnum):
    """Multipart parser states.

    These are used to keep track of the state of the parser, and are used to determine
    what to do when new data is encountered.

    NOTE(review): the MultipartParser state machine that consumes these is not
    fully visible in this chunk; see its implementation for exact transitions.
    """

    START = 0
    START_BOUNDARY = 1
    HEADER_FIELD_START = 2
    HEADER_FIELD = 3
    HEADER_VALUE_START = 4
    HEADER_VALUE = 5
    HEADER_VALUE_ALMOST_DONE = 6
    HEADERS_ALMOST_DONE = 7
    PART_DATA_START = 8
    PART_DATA = 9
    PART_DATA_END = 10
    END_BOUNDARY = 11
    END = 12
134
135
# Flags for the multipart parser.
FLAG_PART_BOUNDARY = 1
FLAG_LAST_BOUNDARY = 2

# Byte constants. Indexing a one-byte bytes literal (e.g. b"\r"[0]) yields the
# integer ordinal of that character, which is what the parsers compare against
# when scanning input byte-by-byte (e.g. `ch = data[i]` in the state machines).
CR = b"\r"[0]
LF = b"\n"[0]
COLON = b":"[0]
SPACE = b" "[0]
HYPHEN = b"-"[0]
AMPERSAND = b"&"[0]
SEMICOLON = b";"[0]
LOWER_A = b"a"[0]
LOWER_Z = b"z"[0]
NULL = b"\x00"[0]

# fmt: off
# Mask for ASCII characters that can be http tokens.
# Per RFC7230 - 3.2.6, this is all alpha-numeric characters
# and these: !#$%&'*+-.^_`|~
TOKEN_CHARS_SET = frozenset(
    b"ABCDEFGHIJKLMNOPQRSTUVWXYZ"
    b"abcdefghijklmnopqrstuvwxyz"
    b"0123456789"
    b"!#$%&'*+-.^_`|~")
# fmt: on
164
165
166def parse_options_header(value: str | bytes | None) -> tuple[bytes, dict[bytes, bytes]]:
167 """Parses a Content-Type header into a value in the following format: (content_type, {parameters})."""
168 # Uses email.message.Message to parse the header as described in PEP 594.
169 # Ref: https://peps.python.org/pep-0594/#cgi
170 if not value:
171 return (b"", {})
172
173 # If we are passed bytes, we assume that it conforms to WSGI, encoding in latin-1.
174 if isinstance(value, bytes): # pragma: no cover
175 value = value.decode("latin-1")
176
177 # For types
178 assert isinstance(value, str), "Value should be a string by now"
179
180 # If we have no options, return the string as-is.
181 if ";" not in value:
182 return (value.lower().strip().encode("latin-1"), {})
183
184 # Split at the first semicolon, to get our value and then options.
185 # ctype, rest = value.split(b';', 1)
186 message = Message()
187 message["content-type"] = value
188 params = message.get_params()
189 # If there were no parameters, this would have already returned above
190 assert params, "At least the content type value should be present"
191 ctype = params.pop(0)[0].encode("latin-1")
192 options: dict[bytes, bytes] = {}
193 for param in params:
194 key, value = param
195 # If the value returned from get_params() is a 3-tuple, the last
196 # element corresponds to the value.
197 # See: https://docs.python.org/3/library/email.compat32-message.html
198 if isinstance(value, tuple):
199 value = value[-1]
200 # If the value is a filename, we need to fix a bug on IE6 that sends
201 # the full file path instead of the filename.
202 if key == "filename":
203 if value[1:3] == ":\\" or value[:2] == "\\\\":
204 value = value.split("\\")[-1]
205 options[key.encode("latin-1")] = value.encode("latin-1")
206 return ctype, options
207
208
class Field:
    """A single parsed form field: a name plus an accumulated value.

    The name a :class:`Field` is created with is the same name that would be
    found in the following HTML::

        <input name="name_goes_here" type="text"/>

    Data arrives through :meth:`on_data` and the value becomes final once
    :meth:`on_end` has been called.

    Args:
        name: The name of the form field.
    """

    def __init__(self, name: bytes | None) -> None:
        self._name = name
        self._value: list[bytes] = []

        # Cached b"".join(self._value); the _missing sentinel means the
        # cache has not been computed (or is stale) and must be rebuilt.
        self._cache = _missing

    @classmethod
    def from_value(cls, name: bytes, value: bytes | None) -> Field:
        """Build an already-finalized :class:`Field` from a name and value.

        Args:
            name: the name of the form field.
            value: the value of the form field - either a bytestring or None.

        Returns:
            A new instance of a [`Field`][python_multipart.Field].
        """
        field = cls(name)
        if value is None:
            field.set_none()
        else:
            field.write(value)
        field.finalize()
        return field

    def write(self, data: bytes) -> int:
        """Write some data into the form field.

        Args:
            data: The data to write to the field.

        Returns:
            The number of bytes written.
        """
        return self.on_data(data)

    def on_data(self, data: bytes) -> int:
        """Callback invoked whenever data is written to this Field.

        Args:
            data: The data to write to the field.

        Returns:
            The number of bytes written.
        """
        self._value.append(data)
        # Any previously joined value is now stale.
        self._cache = _missing
        return len(data)

    def on_end(self) -> None:
        """Callback invoked when the Field is finalized."""
        if self._cache is _missing:
            self._cache = b"".join(self._value)

    def finalize(self) -> None:
        """Finalize the form field."""
        self.on_end()

    def close(self) -> None:
        """Close the Field object, freeing the underlying chunk list."""
        # Make sure the joined value survives before dropping the chunks.
        if self._cache is _missing:
            self._cache = b"".join(self._value)

        del self._value

    def set_none(self) -> None:
        """Mark this field's value as None. A querystring such as
        "foo&bar=&baz=asdf" yields a field "foo" with value None, "bar" with
        value "", and "baz" with value "asdf"; since write() cannot express
        None, this method exists to record that case explicitly.
        """
        self._cache = None

    @property
    def field_name(self) -> bytes | None:
        """The name of this field."""
        return self._name

    @property
    def value(self) -> bytes | None:
        """The value of this field, or None if :meth:`set_none` was called."""
        if self._cache is _missing:
            self._cache = b"".join(self._value)

        assert isinstance(self._cache, bytes) or self._cache is None
        return self._cache

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, Field):
            return NotImplemented
        return self.field_name == other.field_name and self.value == other.value

    def __repr__(self) -> str:
        val = self.value
        if val is not None and len(val) > 97:
            # Truncate long values, splicing an ellipsis in before the
            # closing quote of the repr.
            shown = repr(val[:97])[:-1] + "...'"
        else:
            shown = repr(val)

        return f"{self.__class__.__name__}(field_name={self.field_name!r}, value={shown})"
335
336
class File:
    """This class represents an uploaded file. It handles writing file data to
    either an in-memory file or a temporary file on-disk, if the optional
    threshold is passed.

    There are some options that can be passed to the File to change behavior
    of the class. Valid options are as follows:

    | Name | Type | Default | Description |
    |-----------------------|-------|---------|-------------|
    | UPLOAD_DIR | `str` | None | The directory to store uploaded files in. If this is None, a temporary file will be created in the system's standard location. |
    | UPLOAD_DELETE_TMP | `bool`| True | Delete automatically created TMP file |
    | UPLOAD_KEEP_FILENAME | `bool`| False | Whether or not to keep the filename of the uploaded file. If True, then the filename will be converted to a safe representation (e.g. by removing any invalid path segments), and then saved with the same name). Otherwise, a temporary name will be used. |
    | UPLOAD_KEEP_EXTENSIONS| `bool`| False | Whether or not to keep the uploaded file's extension. If False, the file will be saved with the default temporary extension (usually ".tmp"). Otherwise, the file's extension will be maintained. Note that this will properly combine with the UPLOAD_KEEP_FILENAME setting. |
    | MAX_MEMORY_FILE_SIZE | `int` | 1 MiB | The maximum number of bytes of a File to keep in memory. By default, the contents of a File are kept into memory until a certain limit is reached, after which the contents of the File are written to a temporary file. This behavior can be disabled by setting this value to an appropriately large value (or, for example, infinity, such as `float('inf')`. |

    Args:
        file_name: The name of the file that this [`File`][python_multipart.File] represents.
        field_name: The name of the form field that this file was uploaded with. This can be None, if, for example,
            the file was uploaded with Content-Type application/octet-stream.
        config: The configuration for this File. See above for valid configuration keys and their corresponding values.
    """  # noqa: E501

    def __init__(self, file_name: bytes | None, field_name: bytes | None = None, config: FileConfig = {}) -> None:
        # Save configuration, set other variables default.
        # NOTE: the shared mutable default for `config` is safe here because
        # the dict is only ever read (via .get), never mutated.
        self.logger = logging.getLogger(__name__)
        self._config = config
        self._in_memory = True
        self._bytes_written = 0
        self._fileobj: BytesIO | BufferedRandom = BytesIO()

        # Save the provided field/file name.
        self._field_name = field_name
        self._file_name = file_name

        # Our actual file name is None by default, since, depending on our
        # config, we may not actually use the provided name.
        self._actual_file_name: bytes | None = None

        # Split the extension from the filename. When no filename was given,
        # fall back to empty values so that _get_disk_file() can still build
        # a suffix without raising AttributeError (previously, flushing an
        # unnamed upload to disk with UPLOAD_KEEP_EXTENSIONS set would crash
        # because self._ext was never assigned).
        if file_name is not None:
            base, ext = os.path.splitext(file_name)
        else:
            base, ext = b"", b""
        self._file_base: bytes = base
        self._ext: bytes = ext

    @property
    def field_name(self) -> bytes | None:
        """The form field associated with this file. May be None if there isn't
        one, for example when we have an application/octet-stream upload.
        """
        return self._field_name

    @property
    def file_name(self) -> bytes | None:
        """The file name given in the upload request."""
        return self._file_name

    @property
    def actual_file_name(self) -> bytes | None:
        """The file name that this file is saved as. Will be None if it's not
        currently saved on disk.
        """
        return self._actual_file_name

    @property
    def file_object(self) -> BytesIO | BufferedRandom:
        """The file object that we're currently writing to. Note that this
        will either be an instance of a :class:`io.BytesIO`, or a regular file
        object.
        """
        return self._fileobj

    @property
    def size(self) -> int:
        """The total size of this file, counted as the number of bytes that
        currently have been written to the file.
        """
        return self._bytes_written

    @property
    def in_memory(self) -> bool:
        """A boolean representing whether or not this file object is currently
        stored in-memory or on-disk.
        """
        return self._in_memory

    def flush_to_disk(self) -> None:
        """If the file is already on-disk, do nothing. Otherwise, copy from
        the in-memory buffer to a disk file, and then reassign our internal
        file object to this new disk file.

        Note that if you attempt to flush a file that is already on-disk, a
        warning will be logged to this module's logger.
        """
        if not self._in_memory:
            self.logger.warning("Trying to flush to disk when we're not in memory")
            return

        # Go back to the start of our file.
        self._fileobj.seek(0)

        # Open a new file.
        new_file = self._get_disk_file()

        # Copy the file objects.
        shutil.copyfileobj(self._fileobj, new_file)

        # Seek to the new position in our new file.
        new_file.seek(self._bytes_written)

        # Reassign the fileobject.
        old_fileobj = self._fileobj
        self._fileobj = new_file

        # We're no longer in memory.
        self._in_memory = False

        # Close the old file object.
        old_fileobj.close()

    def _get_disk_file(self) -> BufferedRandom:
        """This function is responsible for getting a file object on-disk for us."""
        self.logger.info("Opening a file on disk")

        file_dir = self._config.get("UPLOAD_DIR")
        keep_filename = self._config.get("UPLOAD_KEEP_FILENAME", False)
        keep_extensions = self._config.get("UPLOAD_KEEP_EXTENSIONS", False)
        delete_tmp = self._config.get("UPLOAD_DELETE_TMP", True)
        tmp_file: None | BufferedRandom = None

        # If we have a directory and are to keep the filename...
        if file_dir is not None and keep_filename:
            self.logger.info("Saving with filename in: %r", file_dir)

            # Build our filename.
            # TODO: what happens if we don't have a filename?
            fname = self._file_base + self._ext if keep_extensions else self._file_base

            path = os.path.join(file_dir, fname)  # type: ignore[arg-type]
            try:
                self.logger.info("Opening file: %r", path)
                tmp_file = open(path, "w+b")
            except OSError:
                tmp_file = None

                self.logger.exception("Error opening temporary file")
                raise FileError("Error opening temporary file: %r" % path)
        else:
            # Build options array.
            # Note that on Python 3, tempfile doesn't support byte names. We
            # encode our paths using the default filesystem encoding.
            suffix = self._ext.decode(sys.getfilesystemencoding()) if keep_extensions else None

            if file_dir is None:
                dir = None
            elif isinstance(file_dir, bytes):
                dir = file_dir.decode(sys.getfilesystemencoding())
            else:
                dir = file_dir  # pragma: no cover

            # Create a temporary (named) file with the appropriate settings.
            self.logger.info(
                "Creating a temporary file with options: %r", {"suffix": suffix, "delete": delete_tmp, "dir": dir}
            )
            try:
                tmp_file = cast(BufferedRandom, tempfile.NamedTemporaryFile(suffix=suffix, delete=delete_tmp, dir=dir))
            except OSError:
                self.logger.exception("Error creating named temporary file")
                raise FileError("Error creating named temporary file")

        assert tmp_file is not None
        # Encode filename as bytes.
        if isinstance(tmp_file.name, str):
            fname = tmp_file.name.encode(sys.getfilesystemencoding())
        else:
            fname = cast(bytes, tmp_file.name)  # pragma: no cover

        self._actual_file_name = fname
        return tmp_file

    def write(self, data: bytes) -> int:
        """Write some data to the File.

        :param data: a bytestring
        """
        return self.on_data(data)

    def on_data(self, data: bytes) -> int:
        """This method is a callback that will be called whenever data is
        written to the File.

        Args:
            data: The data to write to the file.

        Returns:
            The number of bytes written.
        """
        bwritten = self._fileobj.write(data)

        # If the bytes written isn't the same as the length, just return.
        if bwritten != len(data):
            self.logger.warning("bwritten != len(data) (%d != %d)", bwritten, len(data))
            return bwritten

        # Keep track of how many bytes we've written.
        self._bytes_written += bwritten

        # If we're in-memory and are over our limit, we create a file.
        max_memory_file_size = self._config.get("MAX_MEMORY_FILE_SIZE")
        if self._in_memory and max_memory_file_size is not None and (self._bytes_written > max_memory_file_size):
            self.logger.info("Flushing to disk")
            self.flush_to_disk()

        # Return the number of bytes written.
        return bwritten

    def on_end(self) -> None:
        """This method is called whenever the Field is finalized."""
        # Flush the underlying file object
        self._fileobj.flush()

    def finalize(self) -> None:
        """Finalize the form file. This will not close the underlying file,
        but simply signal that we are finished writing to the File.
        """
        self.on_end()

    def close(self) -> None:
        """Close the File object. This will actually close the underlying
        file object (whether it's a :class:`io.BytesIO` or an actual file
        object).
        """
        self._fileobj.close()

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(file_name={self.file_name!r}, field_name={self.field_name!r})"
573
574
class BaseParser:
    """Common base class for all parsers; implements callback registration
    and dispatch.

    Two callback flavors exist. "Notification callbacks" fire when something
    happens - for example, when a new part of a multipart message is found -
    and take no arguments. "Data callbacks" fire when some data arrives - for
    example, part of the body of a multipart chunk - and are invoked as::

        data_callback(data, start, end)

    where ``data`` is a bytestring and the integers ``start`` and ``end``
    delimit the slice ``data[start:end]`` that the callback is "interested
    in". The callback is not passed a copy of the data, since copying
    severely hurts performance.
    """

    def __init__(self) -> None:
        self.logger = logging.getLogger(__name__)
        self.callbacks: QuerystringCallbacks | OctetStreamCallbacks | MultipartCallbacks = {}

    def callback(
        self, name: CallbackName, data: bytes | None = None, start: int | None = None, end: int | None = None
    ) -> None:
        """Invoke the callback registered under ``name``, if any.

        Args:
            name: The name of the callback to call (as a string).
            data: Data to pass to the callback. If None, the callback is
                treated as a notification callback and invoked with no
                arguments.
            start: An integer that is passed to the data callback.
            end: An integer that is passed to the data callback.
        """
        on_name = "on_" + name
        handler = self.callbacks.get(on_name)
        if handler is None:
            return
        handler = cast("Callable[..., Any]", handler)

        if data is None:
            # Notification callback: no arguments at all.
            self.logger.debug("Calling %s with no data", on_name)
            handler()
            return

        # Data callback: an empty slice is not worth dispatching.
        if start is not None and start == end:
            return

        self.logger.debug("Calling %s with data[%d:%d]", on_name, start, end)
        handler(data, start, end)

    def set_callback(self, name: CallbackName, new_func: Callable[..., Any] | None) -> None:
        """Update the function for a callback. Removes from the callbacks dict
        if new_func is None.

        :param name: The name of the callback to call (as a string).

        :param new_func: The new function for the callback. If None, then the
                         callback will be removed (with no error if it does not
                         exist).
        """
        key = "on_" + name
        if new_func is None:
            self.callbacks.pop(key, None)  # type: ignore[misc]
        else:
            self.callbacks[key] = new_func  # type: ignore[literal-required]

    def close(self) -> None:
        pass  # pragma: no cover

    def finalize(self) -> None:
        pass  # pragma: no cover

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}()"
653
654
class OctetStreamParser(BaseParser):
    """This parser parses an octet-stream request body and calls callbacks when
    incoming data is received. Callbacks are as follows:

    | Callback Name | Parameters | Description |
    |----------------|-----------------|-----------------------------------------------------|
    | on_start | None | Called when the first data is parsed. |
    | on_data | data, start, end| Called for each data chunk that is parsed. |
    | on_end | None | Called when the parser is finished parsing all data.|

    Args:
        callbacks: A dictionary of callbacks. See the documentation for [`BaseParser`][python_multipart.BaseParser].
            Defaults to no callbacks.
        max_size: The maximum size of body to parse. Defaults to infinity - i.e. unbounded.
    """

    def __init__(self, callbacks: OctetStreamCallbacks | None = None, max_size: float = float("inf")):
        super().__init__()
        # Use a fresh dict when no callbacks are given: the previous mutable
        # default argument ({}) was shared between every parser instance, so
        # a callback registered via set_callback() on one parser leaked into
        # all parsers constructed with the default.
        self.callbacks = {} if callbacks is None else callbacks
        self._started = False

        if not isinstance(max_size, Number) or max_size < 1:
            raise ValueError("max_size must be a positive number, not %r" % max_size)
        self.max_size: int | float = max_size
        self._current_size = 0

    def write(self, data: bytes) -> int:
        """Write some data to the parser, which will perform size verification,
        and then pass the data to the underlying callback.

        Args:
            data: The data to write to the parser.

        Returns:
            The number of bytes written (possibly less than len(data) if the
            max_size limit was reached).
        """
        if not self._started:
            self.callback("start")
            self._started = True

        # Truncate data length.
        data_len = len(data)
        if (self._current_size + data_len) > self.max_size:
            # We truncate the length of data that we are to process.
            new_size = int(self.max_size - self._current_size)
            self.logger.warning(
                "Current size is %d (max %d), so truncating data length from %d to %d",
                self._current_size,
                self.max_size,
                data_len,
                new_size,
            )
            data_len = new_size

        # Increment size, then callback, in case there's an exception.
        self._current_size += data_len
        self.callback("data", data, 0, data_len)
        return data_len

    def finalize(self) -> None:
        """Finalize this parser, which signals to that we are finished parsing,
        and sends the on_end callback.
        """
        self.callback("end")

    def __repr__(self) -> str:
        return "%s()" % self.__class__.__name__
721
722
723class QuerystringParser(BaseParser):
724 """This is a streaming querystring parser. It will consume data, and call
725 the callbacks given when it has data.
726
727 | Callback Name | Parameters | Description |
728 |----------------|-----------------|-----------------------------------------------------|
729 | on_field_start | None | Called when a new field is encountered. |
730 | on_field_name | data, start, end| Called when a portion of a field's name is encountered. |
731 | on_field_data | data, start, end| Called when a portion of a field's data is encountered. |
732 | on_field_end | None | Called when the end of a field is encountered. |
733 | on_end | None | Called when the parser is finished parsing all data.|
734
735 Args:
736 callbacks: A dictionary of callbacks. See the documentation for [`BaseParser`][python_multipart.BaseParser].
737 strict_parsing: Whether or not to parse the body strictly. Defaults to False. If this is set to True, then the
738 behavior of the parser changes as the following: if a field has a value with an equal sign
739 (e.g. "foo=bar", or "foo="), it is always included. If a field has no equals sign (e.g. "...&name&..."),
740 it will be treated as an error if 'strict_parsing' is True, otherwise included. If an error is encountered,
741 then a [`QuerystringParseError`][python_multipart.exceptions.QuerystringParseError] will be raised.
742 max_size: The maximum size of body to parse. Defaults to infinity - i.e. unbounded.
743 """ # noqa: E501
744
745 state: QuerystringState
746
747 def __init__(
748 self, callbacks: QuerystringCallbacks = {}, strict_parsing: bool = False, max_size: float = float("inf")
749 ) -> None:
750 super().__init__()
751 self.state = QuerystringState.BEFORE_FIELD
752 self._found_sep = False
753
754 self.callbacks = callbacks
755
756 # Max-size stuff
757 if not isinstance(max_size, Number) or max_size < 1:
758 raise ValueError("max_size must be a positive number, not %r" % max_size)
759 self.max_size: int | float = max_size
760 self._current_size = 0
761
762 # Should parsing be strict?
763 self.strict_parsing = strict_parsing
764
765 def write(self, data: bytes) -> int:
766 """Write some data to the parser, which will perform size verification,
767 parse into either a field name or value, and then pass the
768 corresponding data to the underlying callback. If an error is
769 encountered while parsing, a QuerystringParseError will be raised. The
770 "offset" attribute of the raised exception will be set to the offset in
771 the input data chunk (NOT the overall stream) that caused the error.
772
773 Args:
774 data: The data to write to the parser.
775
776 Returns:
777 The number of bytes written.
778 """
779 # Handle sizing.
780 data_len = len(data)
781 if (self._current_size + data_len) > self.max_size:
782 # We truncate the length of data that we are to process.
783 new_size = int(self.max_size - self._current_size)
784 self.logger.warning(
785 "Current size is %d (max %d), so truncating data length from %d to %d",
786 self._current_size,
787 self.max_size,
788 data_len,
789 new_size,
790 )
791 data_len = new_size
792
793 l = 0
794 try:
795 l = self._internal_write(data, data_len)
796 finally:
797 self._current_size += l
798
799 return l
800
801 def _internal_write(self, data: bytes, length: int) -> int:
802 state = self.state
803 strict_parsing = self.strict_parsing
804 found_sep = self._found_sep
805
806 i = 0
807 while i < length:
808 ch = data[i]
809
810 # Depending on our state...
811 if state == QuerystringState.BEFORE_FIELD:
812 # If the 'found_sep' flag is set, we've already encountered
813 # and skipped a single separator. If so, we check our strict
814 # parsing flag and decide what to do. Otherwise, we haven't
815 # yet reached a separator, and thus, if we do, we need to skip
816 # it as it will be the boundary between fields that's supposed
817 # to be there.
818 if ch == AMPERSAND or ch == SEMICOLON:
819 if found_sep:
820 # If we're parsing strictly, we disallow blank chunks.
821 if strict_parsing:
822 e = QuerystringParseError("Skipping duplicate ampersand/semicolon at %d" % i)
823 e.offset = i
824 raise e
825 else:
826 self.logger.debug("Skipping duplicate ampersand/semicolon at %d", i)
827 else:
828 # This case is when we're skipping the (first)
829 # separator between fields, so we just set our flag
830 # and continue on.
831 found_sep = True
832 else:
833 # Emit a field-start event, and go to that state. Also,
834 # reset the "found_sep" flag, for the next time we get to
835 # this state.
836 self.callback("field_start")
837 i -= 1
838 state = QuerystringState.FIELD_NAME
839 found_sep = False
840
841 elif state == QuerystringState.FIELD_NAME:
842 # Try and find a separator - we ensure that, if we do, we only
843 # look for the equal sign before it.
844 sep_pos = data.find(b"&", i)
845 if sep_pos == -1:
846 sep_pos = data.find(b";", i)
847
848 # See if we can find an equals sign in the remaining data. If
849 # so, we can immediately emit the field name and jump to the
850 # data state.
851 if sep_pos != -1:
852 equals_pos = data.find(b"=", i, sep_pos)
853 else:
854 equals_pos = data.find(b"=", i)
855
856 if equals_pos != -1:
857 # Emit this name.
858 self.callback("field_name", data, i, equals_pos)
859
860 # Jump i to this position. Note that it will then have 1
861 # added to it below, which means the next iteration of this
862 # loop will inspect the character after the equals sign.
863 i = equals_pos
864 state = QuerystringState.FIELD_DATA
865 else:
866 # No equals sign found.
867 if not strict_parsing:
868 # See also comments in the QuerystringState.FIELD_DATA case below.
869 # If we found the separator, we emit the name and just
870 # end - there's no data callback at all (not even with
871 # a blank value).
872 if sep_pos != -1:
873 self.callback("field_name", data, i, sep_pos)
874 self.callback("field_end")
875
876 i = sep_pos - 1
877 state = QuerystringState.BEFORE_FIELD
878 else:
879 # Otherwise, no separator in this block, so the
880 # rest of this chunk must be a name.
881 self.callback("field_name", data, i, length)
882 i = length
883
884 else:
885 # We're parsing strictly. If we find a separator,
886 # this is an error - we require an equals sign.
887 if sep_pos != -1:
888 e = QuerystringParseError(
889 "When strict_parsing is True, we require an "
890 "equals sign in all field chunks. Did not "
891 "find one in the chunk that starts at %d" % (i,)
892 )
893 e.offset = i
894 raise e
895
896 # No separator in the rest of this chunk, so it's just
897 # a field name.
898 self.callback("field_name", data, i, length)
899 i = length
900
901 elif state == QuerystringState.FIELD_DATA:
902 # Try finding either an ampersand or a semicolon after this
903 # position.
904 sep_pos = data.find(b"&", i)
905 if sep_pos == -1:
906 sep_pos = data.find(b";", i)
907
908 # If we found it, callback this bit as data and then go back
909 # to expecting to find a field.
910 if sep_pos != -1:
911 self.callback("field_data", data, i, sep_pos)
912 self.callback("field_end")
913
914 # Note that we go to the separator, which brings us to the
915 # "before field" state. This allows us to properly emit
916 # "field_start" events only when we actually have data for
917 # a field of some sort.
918 i = sep_pos - 1
919 state = QuerystringState.BEFORE_FIELD
920
921 # Otherwise, emit the rest as data and finish.
922 else:
923 self.callback("field_data", data, i, length)
924 i = length
925
926 else: # pragma: no cover (error case)
927 msg = "Reached an unknown state %d at %d" % (state, i)
928 self.logger.warning(msg)
929 e = QuerystringParseError(msg)
930 e.offset = i
931 raise e
932
933 i += 1
934
935 self.state = state
936 self._found_sep = found_sep
937 return len(data)
938
    def finalize(self) -> None:
        """Finalize this parser, signalling that there is no more input.

        If we are still in the middle of a field (i.e. the last chunk ended
        inside field data with no trailing separator), the on_field_end
        callback is emitted first to close that field; the on_end callback
        is always emitted last.
        """
        # If we're currently in the middle of a field, we finish it.
        if self.state == QuerystringState.FIELD_DATA:
            self.callback("field_end")
        self.callback("end")
948
949 def __repr__(self) -> str:
950 return "{}(strict_parsing={!r}, max_size={!r})".format(
951 self.__class__.__name__, self.strict_parsing, self.max_size
952 )
953
954
class MultipartParser(BaseParser):
    """This class is a streaming multipart/form-data parser.

    | Callback Name      | Parameters      | Description |
    |--------------------|-----------------|-------------|
    | on_part_begin      | None            | Called when a new part of the multipart message is encountered. |
    | on_part_data       | data, start, end| Called when a portion of a part's data is encountered. |
    | on_part_end        | None            | Called when the end of a part is reached. |
    | on_header_begin    | None            | Called when we've found a new header in a part of a multipart message |
    | on_header_field    | data, start, end| Called each time an additional portion of a header is read (i.e. the part of the header that is before the colon; the "Foo" in "Foo: Bar"). |
    | on_header_value    | data, start, end| Called when we get data for a header. |
    | on_header_end      | None            | Called when the current header is finished - i.e. we've reached the newline at the end of the header. |
    | on_headers_finished| None            | Called when all headers are finished, and before the part data starts. |
    | on_end             | None            | Called when the parser is finished parsing all data. |

    Args:
        boundary: The multipart boundary. This is required, and must match what is given in the HTTP request - usually in the Content-Type header.
        callbacks: A dictionary of callbacks. See the documentation for [`BaseParser`][python_multipart.BaseParser].
        max_size: The maximum size of body to parse. Defaults to infinity - i.e. unbounded.
    """  # noqa: E501

    def __init__(
        self, boundary: bytes | str, callbacks: MultipartCallbacks = {}, max_size: float = float("inf")
    ) -> None:
        # Initialize parser state.
        super().__init__()
        self.state = MultipartState.START
        self.index = self.flags = 0

        # NOTE(review): the default `{}` is a single shared dict; if a
        # default-constructed parser's `self.callbacks` is ever mutated
        # (e.g. via a set_callback-style helper), the change would leak into
        # every later default-constructed instance — confirm callers only
        # ever pass their own dict or treat this one as read-only.
        self.callbacks = callbacks

        if not isinstance(max_size, Number) or max_size < 1:
            raise ValueError("max_size must be a positive number, not %r" % max_size)
        self.max_size = max_size
        self._current_size = 0

        # Setup marks. These are used to track the state of data received.
        self.marks: dict[str, int] = {}

        # Save our boundary.
        if isinstance(boundary, str):  # pragma: no cover
            boundary = boundary.encode("latin-1")
        self.boundary = b"\r\n--" + boundary

    def write(self, data: bytes) -> int:
        """Write some data to the parser, which will perform size verification,
        and then parse the data into the appropriate location (e.g. header,
        data, etc.), and pass this on to the underlying callback. If an error
        is encountered, a MultipartParseError will be raised. The "offset"
        attribute on the raised exception will be set to the offset of the byte
        in the input chunk that caused the error.

        Args:
            data: The data to write to the parser.

        Returns:
            The number of bytes written.
        """
        # Handle sizing.
        data_len = len(data)
        if (self._current_size + data_len) > self.max_size:
            # We truncate the length of data that we are to process.
            new_size = int(self.max_size - self._current_size)
            self.logger.warning(
                "Current size is %d (max %d), so truncating data length from %d to %d",
                self._current_size,
                self.max_size,
                data_len,
                new_size,
            )
            data_len = new_size

        # Track the number of bytes actually consumed so the running size is
        # updated even if the state machine raises partway through the chunk.
        l = 0
        try:
            l = self._internal_write(data, data_len)
        finally:
            self._current_size += l

        return l

    def _internal_write(self, data: bytes, length: int) -> int:
        """Run the multipart state machine over ``data[:length]``.

        Returns ``length`` on success; raises MultipartParseError (with an
        ``offset`` attribute pointing at the offending byte) on malformed
        input.
        """
        # Get values from locals.
        boundary = self.boundary

        # Get our state, flags and index. These are persisted between calls to
        # this function.
        state = self.state
        index = self.index
        flags = self.flags

        # Our index defaults to 0.
        i = 0

        # Set a mark.
        def set_mark(name: str) -> None:
            self.marks[name] = i

        # Remove a mark.
        # NOTE(review): the `reset` parameter is accepted but unused; the mark
        # is always deleted outright.
        def delete_mark(name: str, reset: bool = False) -> None:
            self.marks.pop(name, None)

        # Helper function that makes calling a callback with data easier. The
        # 'remaining' parameter will callback from the marked value until the
        # end of the buffer, and reset the mark, instead of deleting it. This
        # is used at the end of the function to call our callbacks with any
        # remaining data in this chunk.
        def data_callback(name: CallbackName, end_i: int, remaining: bool = False) -> None:
            marked_index = self.marks.get(name)
            if marked_index is None:
                return

            # Otherwise, we call it from the mark to the current byte we're
            # processing.
            if end_i <= marked_index:
                # There is no additional data to send.
                pass
            elif marked_index >= 0:
                # We are emitting data from the local buffer.
                self.callback(name, data, marked_index, end_i)
            else:
                # Some of the data comes from a partial boundary match.
                # and requires look-behind.
                # We need to use self.flags (and not flags) because we care about
                # the state when we entered the loop.
                lookbehind_len = -marked_index
                if lookbehind_len <= len(boundary):
                    self.callback(name, boundary, 0, lookbehind_len)
                elif self.flags & FLAG_PART_BOUNDARY:
                    lookback = boundary + b"\r\n"
                    self.callback(name, lookback, 0, lookbehind_len)
                elif self.flags & FLAG_LAST_BOUNDARY:
                    lookback = boundary + b"--\r\n"
                    self.callback(name, lookback, 0, lookbehind_len)
                else:  # pragma: no cover (error case)
                    self.logger.warning("Look-back buffer error")

                if end_i > 0:
                    self.callback(name, data, 0, end_i)
            # If we're getting remaining data, we have got all the data we
            # can be certain is not a boundary, leaving only a partial boundary match.
            if remaining:
                self.marks[name] = end_i - length
            else:
                self.marks.pop(name, None)

        # For each byte...
        while i < length:
            c = data[i]

            if state == MultipartState.START:
                # Skip leading newlines
                if c == CR or c == LF:
                    i += 1
                    continue

                # index is used as in index into our boundary. Set to 0.
                index = 0

                # Move to the next state, but decrement i so that we re-process
                # this character.
                state = MultipartState.START_BOUNDARY
                i -= 1

            elif state == MultipartState.START_BOUNDARY:
                # Check to ensure that the last 2 characters in our boundary
                # are CRLF.
                if index == len(boundary) - 2:
                    if c == HYPHEN:
                        # Potential empty message.
                        state = MultipartState.END_BOUNDARY
                    elif c != CR:
                        # Error!
                        msg = "Did not find CR at end of boundary (%d)" % (i,)
                        self.logger.warning(msg)
                        e = MultipartParseError(msg)
                        e.offset = i
                        raise e

                    index += 1

                elif index == len(boundary) - 2 + 1:
                    if c != LF:
                        msg = "Did not find LF at end of boundary (%d)" % (i,)
                        self.logger.warning(msg)
                        e = MultipartParseError(msg)
                        e.offset = i
                        raise e

                    # The index is now used for indexing into our boundary.
                    index = 0

                    # Callback for the start of a part.
                    self.callback("part_begin")

                    # Move to the next character and state.
                    state = MultipartState.HEADER_FIELD_START

                else:
                    # Check to ensure our boundary matches
                    if c != boundary[index + 2]:
                        msg = "Expected boundary character %r, got %r at index %d" % (boundary[index + 2], c, index + 2)
                        self.logger.warning(msg)
                        e = MultipartParseError(msg)
                        e.offset = i
                        raise e

                    # Increment index into boundary and continue.
                    index += 1

            elif state == MultipartState.HEADER_FIELD_START:
                # Mark the start of a header field here, reset the index, and
                # continue parsing our header field.
                index = 0

                # Set a mark of our header field.
                set_mark("header_field")

                # Notify that we're starting a header if the next character is
                # not a CR; a CR at the beginning of the header will cause us
                # to stop parsing headers in the MultipartState.HEADER_FIELD state,
                # below.
                if c != CR:
                    self.callback("header_begin")

                # Move to parsing header fields.
                state = MultipartState.HEADER_FIELD
                i -= 1

            elif state == MultipartState.HEADER_FIELD:
                # If we've reached a CR at the beginning of a header, it means
                # that we've reached the second of 2 newlines, and so there are
                # no more headers to parse.
                if c == CR and index == 0:
                    delete_mark("header_field")
                    state = MultipartState.HEADERS_ALMOST_DONE
                    i += 1
                    continue

                # Increment our index in the header.
                index += 1

                # If we've reached a colon, we're done with this header.
                if c == COLON:
                    # A 0-length header is an error.
                    if index == 1:
                        msg = "Found 0-length header at %d" % (i,)
                        self.logger.warning(msg)
                        e = MultipartParseError(msg)
                        e.offset = i
                        raise e

                    # Call our callback with the header field.
                    data_callback("header_field", i)

                    # Move to parsing the header value.
                    state = MultipartState.HEADER_VALUE_START

                elif c not in TOKEN_CHARS_SET:
                    msg = "Found invalid character %r in header at %d" % (c, i)
                    self.logger.warning(msg)
                    e = MultipartParseError(msg)
                    e.offset = i
                    raise e

            elif state == MultipartState.HEADER_VALUE_START:
                # Skip leading spaces.
                if c == SPACE:
                    i += 1
                    continue

                # Mark the start of the header value.
                set_mark("header_value")

                # Move to the header-value state, reprocessing this character.
                state = MultipartState.HEADER_VALUE
                i -= 1

            elif state == MultipartState.HEADER_VALUE:
                # If we've got a CR, we're nearly done our headers. Otherwise,
                # we do nothing and just move past this character.
                if c == CR:
                    data_callback("header_value", i)
                    self.callback("header_end")
                    state = MultipartState.HEADER_VALUE_ALMOST_DONE

            elif state == MultipartState.HEADER_VALUE_ALMOST_DONE:
                # The last character should be a LF. If not, it's an error.
                if c != LF:
                    msg = f"Did not find LF character at end of header (found {c!r})"
                    self.logger.warning(msg)
                    e = MultipartParseError(msg)
                    e.offset = i
                    raise e

                # Move back to the start of another header. Note that if that
                # state detects ANOTHER newline, it'll trigger the end of our
                # headers.
                state = MultipartState.HEADER_FIELD_START

            elif state == MultipartState.HEADERS_ALMOST_DONE:
                # We're almost done our headers. This is reached when we parse
                # a CR at the beginning of a header, so our next character
                # should be a LF, or it's an error.
                if c != LF:
                    msg = f"Did not find LF at end of headers (found {c!r})"
                    self.logger.warning(msg)
                    e = MultipartParseError(msg)
                    e.offset = i
                    raise e

                self.callback("headers_finished")
                state = MultipartState.PART_DATA_START

            elif state == MultipartState.PART_DATA_START:
                # Mark the start of our part data.
                set_mark("part_data")

                # Start processing part data, including this character.
                state = MultipartState.PART_DATA
                i -= 1

            elif state == MultipartState.PART_DATA:
                # We're processing our part data right now. During this, we
                # need to efficiently search for our boundary, since any data
                # on any number of lines can be a part of the current data.

                # Save the current value of our index. We use this in case we
                # find part of a boundary, but it doesn't match fully.
                prev_index = index

                # Set up variables.
                boundary_length = len(boundary)
                data_length = length

                # If our index is 0, we're starting a new part, so start our
                # search.
                if index == 0:
                    # The most common case is likely to be that the whole
                    # boundary is present in the buffer.
                    # Calling `find` is much faster than iterating here.
                    i0 = data.find(boundary, i, data_length)
                    if i0 >= 0:
                        # We matched the whole boundary string.
                        index = boundary_length - 1
                        i = i0 + boundary_length - 1
                    else:
                        # No match found for whole string.
                        # There may be a partial boundary at the end of the
                        # data, which the find will not match.
                        # Since the length should to be searched is limited to
                        # the boundary length, just perform a naive search.
                        i = max(i, data_length - boundary_length)

                        # Search forward until we either hit the end of our buffer,
                        # or reach a potential start of the boundary.
                        while i < data_length - 1 and data[i] != boundary[0]:
                            i += 1

                    c = data[i]

                # Now, we have a couple of cases here. If our index is before
                # the end of the boundary...
                if index < boundary_length:
                    # If the character matches...
                    if boundary[index] == c:
                        # The current character matches, so continue!
                        index += 1
                    else:
                        index = 0

                # Our index is equal to the length of our boundary!
                elif index == boundary_length:
                    # First we increment it.
                    index += 1

                    # Now, if we've reached a newline, we need to set this as
                    # the potential end of our boundary.
                    if c == CR:
                        flags |= FLAG_PART_BOUNDARY

                    # Otherwise, if this is a hyphen, we might be at the last
                    # of all boundaries.
                    elif c == HYPHEN:
                        flags |= FLAG_LAST_BOUNDARY

                    # Otherwise, we reset our index, since this isn't either a
                    # newline or a hyphen.
                    else:
                        index = 0

                # Our index is right after the part boundary, which should be
                # a LF.
                elif index == boundary_length + 1:
                    # If we're at a part boundary (i.e. we've seen a CR
                    # character already)...
                    if flags & FLAG_PART_BOUNDARY:
                        # We need a LF character next.
                        if c == LF:
                            # Unset the part boundary flag.
                            flags &= ~FLAG_PART_BOUNDARY

                            # We have identified a boundary, callback for any data before it.
                            data_callback("part_data", i - index)
                            # Callback indicating that we've reached the end of
                            # a part, and are starting a new one.
                            self.callback("part_end")
                            self.callback("part_begin")

                            # Move to parsing new headers.
                            index = 0
                            state = MultipartState.HEADER_FIELD_START
                            i += 1
                            continue

                        # We didn't find an LF character, so no match. Reset
                        # our index and clear our flag.
                        index = 0
                        flags &= ~FLAG_PART_BOUNDARY

                    # Otherwise, if we're at the last boundary (i.e. we've
                    # seen a hyphen already)...
                    elif flags & FLAG_LAST_BOUNDARY:
                        # We need a second hyphen here.
                        if c == HYPHEN:
                            # We have identified a boundary, callback for any data before it.
                            data_callback("part_data", i - index)
                            # Callback to end the current part, and then the
                            # message.
                            self.callback("part_end")
                            self.callback("end")
                            state = MultipartState.END
                        else:
                            # No match, so reset index.
                            index = 0

                # Otherwise, our index is 0. If the previous index is not, it
                # means we reset something, and we need to take the data we
                # thought was part of our boundary and send it along as actual
                # data.
                if index == 0 and prev_index > 0:
                    # Overwrite our previous index.
                    prev_index = 0

                    # Re-consider the current character, since this could be
                    # the start of the boundary itself.
                    i -= 1

            elif state == MultipartState.END_BOUNDARY:
                if index == len(boundary) - 2 + 1:
                    if c != HYPHEN:
                        msg = "Did not find - at end of boundary (%d)" % (i,)
                        self.logger.warning(msg)
                        e = MultipartParseError(msg)
                        e.offset = i
                        raise e
                    index += 1
                    self.callback("end")
                    state = MultipartState.END

            elif state == MultipartState.END:
                # Don't do anything if chunk ends with CRLF.
                if c == CR and i + 1 < length and data[i + 1] == LF:
                    i += 2
                    continue
                # Skip data after the last boundary.
                self.logger.warning("Skipping data after last boundary")
                i = length
                break

            else:  # pragma: no cover (error case)
                # We got into a strange state somehow! Just stop processing.
                msg = "Reached an unknown state %d at %d" % (state, i)
                self.logger.warning(msg)
                e = MultipartParseError(msg)
                e.offset = i
                raise e

            # Move to the next byte.
            i += 1

        # We call our callbacks with any remaining data. Note that we pass
        # the 'remaining' flag, which sets the mark back to 0 instead of
        # deleting it, if it's found. This is because, if the mark is found
        # at this point, we assume that there's data for one of these things
        # that has been parsed, but not yet emitted. And, as such, it implies
        # that we haven't yet reached the end of this 'thing'. So, by setting
        # the mark to 0, we cause any data callbacks that take place in future
        # calls to this function to start from the beginning of that buffer.
        data_callback("header_field", length, True)
        data_callback("header_value", length, True)
        data_callback("part_data", length - index, True)

        # Save values to locals.
        self.state = state
        self.index = index
        self.flags = flags

        # Return our data length to indicate no errors, and that we processed
        # all of it.
        return length

    def finalize(self) -> None:
        """Finalize this parser, which signals to that we are finished parsing.

        Note: It does not currently, but in the future, it will verify that we
        are in the final state of the parser (i.e. the end of the multipart
        message is well-formed), and, if not, throw an error.
        """
        # TODO: verify that we're in the state MultipartState.END, otherwise throw an
        # error or otherwise state that we're not finished parsing.
        pass

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(boundary={self.boundary!r})"
1469
1470
class FormParser:
    """This class is the all-in-one form parser. Given all the information
    necessary to parse a form, it will instantiate the correct parser, create
    the proper :class:`Field` and :class:`File` classes to store the data that
    is parsed, and call the two given callbacks with each field and file as
    they become available.

    Args:
        content_type: The Content-Type of the incoming request. This is used to select the appropriate parser.
        on_field: The callback to call when a field has been parsed and is ready for usage. See above for parameters.
        on_file: The callback to call when a file has been parsed and is ready for usage. See above for parameters.
        on_end: An optional callback to call when all fields and files in a request has been parsed. Can be None.
        boundary: If the request is a multipart/form-data request, this should be the boundary of the request, as given
            in the Content-Type header, as a bytestring.
        file_name: If the request is of type application/octet-stream, then the body of the request will not contain any
            information about the uploaded file. In such cases, you can provide the file name of the uploaded file
            manually.
        FileClass: The class to use for uploaded files. Defaults to :class:`File`, but you can provide your own class
            if you wish to customize behaviour. The class will be instantiated as FileClass(file_name, field_name), and
            it must provide the following functions::
                - file_instance.write(data)
                - file_instance.finalize()
                - file_instance.close()
        FieldClass: The class to use for uploaded fields. Defaults to :class:`Field`, but you can provide your own
            class if you wish to customize behaviour. The class will be instantiated as FieldClass(field_name), and it
            must provide the following functions::
                - field_instance.write(data)
                - field_instance.finalize()
                - field_instance.close()
                - field_instance.set_none()
        config: Configuration to use for this FormParser. The default values are taken from the DEFAULT_CONFIG value,
            and then any keys present in this dictionary will overwrite the default values.

    Raises:
        FormParserError: If the Content-Type is unknown, if a multipart request has no
            boundary, or (when configured) on a bad Content-Transfer-Encoding.
    """

    #: This is the default configuration for our form parser.
    #: Note: all file sizes should be in bytes.
    DEFAULT_CONFIG: FormParserConfig = {
        "MAX_BODY_SIZE": float("inf"),
        "MAX_MEMORY_FILE_SIZE": 1 * 1024 * 1024,
        "UPLOAD_DIR": None,
        "UPLOAD_KEEP_FILENAME": False,
        "UPLOAD_KEEP_EXTENSIONS": False,
        # Error on invalid Content-Transfer-Encoding?
        "UPLOAD_ERROR_ON_BAD_CTE": False,
    }

    def __init__(
        self,
        content_type: str,
        on_field: OnFieldCallback | None,
        on_file: OnFileCallback | None,
        on_end: Callable[[], None] | None = None,
        boundary: bytes | str | None = None,
        file_name: bytes | None = None,
        FileClass: type[FileProtocol] = File,
        FieldClass: type[FieldProtocol] = Field,
        config: dict[Any, Any] = {},
    ) -> None:
        self.logger = logging.getLogger(__name__)

        # Save variables.
        self.content_type = content_type
        self.boundary = boundary
        self.bytes_received = 0
        self.parser = None

        # Save callbacks.
        self.on_field = on_field
        self.on_file = on_file
        self.on_end = on_end

        # Save classes. BUGFIX: store the classes that were passed in, not the
        # module-level defaults - previously custom FileClass/FieldClass
        # arguments were silently dropped from these attributes even though
        # the closures below used them, leaving the attributes inconsistent
        # with the parser's actual behavior.
        self.FileClass = FileClass
        self.FieldClass = FieldClass

        # Set configuration options: start from the defaults and overlay any
        # user-supplied keys.
        self.config: FormParserConfig = self.DEFAULT_CONFIG.copy()
        self.config.update(config)  # type: ignore[typeddict-item]

        parser: OctetStreamParser | MultipartParser | QuerystringParser | None = None

        # Depending on the Content-Type, we instantiate the correct parser.
        if content_type == "application/octet-stream":
            file: FileProtocol = None  # type: ignore

            def on_start() -> None:
                nonlocal file
                file = FileClass(file_name, None, config=cast("FileConfig", self.config))

            def on_data(data: bytes, start: int, end: int) -> None:
                nonlocal file
                file.write(data[start:end])

            def _on_end() -> None:
                nonlocal file
                # Finalize the file itself.
                file.finalize()

                # Call our callback.
                if on_file:
                    on_file(file)

                # Call the on-end callback.
                if self.on_end is not None:
                    self.on_end()

            # Instantiate an octet-stream parser
            parser = OctetStreamParser(
                callbacks={"on_start": on_start, "on_data": on_data, "on_end": _on_end},
                max_size=self.config["MAX_BODY_SIZE"],
            )

        elif content_type == "application/x-www-form-urlencoded" or content_type == "application/x-url-encoded":
            name_buffer: list[bytes] = []

            f: FieldProtocol | None = None

            def on_field_start() -> None:
                pass

            def on_field_name(data: bytes, start: int, end: int) -> None:
                name_buffer.append(data[start:end])

            def on_field_data(data: bytes, start: int, end: int) -> None:
                nonlocal f
                if f is None:
                    f = FieldClass(b"".join(name_buffer))
                    del name_buffer[:]
                f.write(data[start:end])

            def on_field_end() -> None:
                nonlocal f
                # Finalize and call callback.
                if f is None:
                    # If we get here, it's because there was no field data.
                    # We create a field, set it to None, and then continue.
                    f = FieldClass(b"".join(name_buffer))
                    del name_buffer[:]
                    f.set_none()

                f.finalize()
                if on_field:
                    on_field(f)
                f = None

            def _on_end() -> None:
                if self.on_end is not None:
                    self.on_end()

            # Instantiate parser.
            parser = QuerystringParser(
                callbacks={
                    "on_field_start": on_field_start,
                    "on_field_name": on_field_name,
                    "on_field_data": on_field_data,
                    "on_field_end": on_field_end,
                    "on_end": _on_end,
                },
                max_size=self.config["MAX_BODY_SIZE"],
            )

        elif content_type == "multipart/form-data":
            if boundary is None:
                self.logger.error("No boundary given")
                raise FormParserError("No boundary given")

            header_name: list[bytes] = []
            header_value: list[bytes] = []
            headers: dict[bytes, bytes] = {}

            f_multi: FileProtocol | FieldProtocol | None = None
            writer = None
            is_file = False

            def on_part_begin() -> None:
                # Reset headers in case this isn't the first part.
                nonlocal headers
                headers = {}

            def on_part_data(data: bytes, start: int, end: int) -> None:
                nonlocal writer
                assert writer is not None
                writer.write(data[start:end])
                # TODO: check for error here.

            def on_part_end() -> None:
                nonlocal f_multi, is_file
                assert f_multi is not None
                f_multi.finalize()
                if is_file:
                    if on_file:
                        on_file(f_multi)
                else:
                    if on_field:
                        on_field(cast("FieldProtocol", f_multi))

            def on_header_field(data: bytes, start: int, end: int) -> None:
                header_name.append(data[start:end])

            def on_header_value(data: bytes, start: int, end: int) -> None:
                header_value.append(data[start:end])

            def on_header_end() -> None:
                headers[b"".join(header_name)] = b"".join(header_value)
                del header_name[:]
                del header_value[:]

            def on_headers_finished() -> None:
                nonlocal is_file, f_multi, writer
                # Reset the 'is file' flag.
                is_file = False

                # Parse the content-disposition header.
                # TODO: handle mixed case
                content_disp = headers.get(b"Content-Disposition")
                disp, options = parse_options_header(content_disp)

                # Get the field and filename.
                field_name = options.get(b"name")
                file_name = options.get(b"filename")
                # TODO: check for errors

                # Create the proper class.
                if file_name is None:
                    f_multi = FieldClass(field_name)
                else:
                    f_multi = FileClass(file_name, field_name, config=cast("FileConfig", self.config))
                    is_file = True

                # Parse the given Content-Transfer-Encoding to determine what
                # we need to do with the incoming data.
                # TODO: check that we properly handle 8bit / 7bit encoding.
                transfer_encoding = headers.get(b"Content-Transfer-Encoding", b"7bit")

                if transfer_encoding in (b"binary", b"8bit", b"7bit"):
                    writer = f_multi

                elif transfer_encoding == b"base64":
                    writer = Base64Decoder(f_multi)

                elif transfer_encoding == b"quoted-printable":
                    writer = QuotedPrintableDecoder(f_multi)

                else:
                    self.logger.warning("Unknown Content-Transfer-Encoding: %r", transfer_encoding)
                    if self.config["UPLOAD_ERROR_ON_BAD_CTE"]:
                        raise FormParserError(f'Unknown Content-Transfer-Encoding "{transfer_encoding!r}"')
                    else:
                        # If we aren't erroring, then we just treat this as an
                        # unencoded Content-Transfer-Encoding.
                        writer = f_multi

            def _on_end() -> None:
                nonlocal writer
                if writer is not None:
                    writer.finalize()
                if self.on_end is not None:
                    self.on_end()

            # Instantiate a multipart parser.
            parser = MultipartParser(
                boundary,
                callbacks={
                    "on_part_begin": on_part_begin,
                    "on_part_data": on_part_data,
                    "on_part_end": on_part_end,
                    "on_header_field": on_header_field,
                    "on_header_value": on_header_value,
                    "on_header_end": on_header_end,
                    "on_headers_finished": on_headers_finished,
                    "on_end": _on_end,
                },
                max_size=self.config["MAX_BODY_SIZE"],
            )

        else:
            self.logger.warning("Unknown Content-Type: %r", content_type)
            raise FormParserError(f"Unknown Content-Type: {content_type}")

        self.parser = parser

    def write(self, data: bytes) -> int:
        """Write some data. The parser will forward this to the appropriate
        underlying parser.

        Args:
            data: The data to write.

        Returns:
            The number of bytes processed.
        """
        self.bytes_received += len(data)
        # TODO: check the parser's return value for errors?
        assert self.parser is not None
        return self.parser.write(data)

    def finalize(self) -> None:
        """Finalize the parser."""
        if self.parser is not None and hasattr(self.parser, "finalize"):
            self.parser.finalize()

    def close(self) -> None:
        """Close the parser."""
        if self.parser is not None and hasattr(self.parser, "close"):
            self.parser.close()

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(content_type={self.content_type!r}, parser={self.parser!r})"
1779
1780
def create_form_parser(
    headers: dict[str, bytes],
    on_field: OnFieldCallback | None,
    on_file: OnFileCallback | None,
    trust_x_headers: bool = False,
    config: dict[Any, Any] = {},
) -> FormParser:
    """This function is a helper function to aid in creating a FormParser
    instances. Given a dictionary-like headers object, it will determine
    the correct information needed, instantiate a FormParser with the
    appropriate values and given callbacks, and then return the corresponding
    parser.

    Args:
        headers: A dictionary-like object of HTTP headers. The only required header is Content-Type.
        on_field: Callback to call with each parsed field.
        on_file: Callback to call with each parsed file.
        trust_x_headers: Whether or not to trust information received from certain X-Headers - for example, the file
            name from X-File-Name.
        config: Configuration variables to pass to the FormParser.

    Raises:
        ValueError: If no Content-Type header is present.
    """
    content_type: str | bytes | None = headers.get("Content-Type")
    if content_type is None:
        logging.getLogger(__name__).warning("No Content-Type header given")
        raise ValueError("No Content-Type header given!")

    # Boundaries are optional (the FormParser will raise if one is needed
    # but not given).
    content_type, params = parse_options_header(content_type)
    boundary = params.get(b"boundary")

    # We need content_type to be a string, not a bytes object.
    content_type = content_type.decode("latin-1")

    # File names are optional. BUGFIX: the X-File-Name header is supplied by
    # the client, so it is only honoured when the caller explicitly opted in
    # via trust_x_headers - previously the flag was accepted but ignored.
    file_name = headers.get("X-File-Name") if trust_x_headers else None

    # Instantiate a form parser.
    form_parser = FormParser(content_type, on_field, on_file, boundary=boundary, file_name=file_name, config=config)

    # Return our parser.
    return form_parser
1823
1824
def parse_form(
    headers: dict[str, bytes],
    input_stream: SupportsRead,
    on_field: OnFieldCallback | None,
    on_file: OnFileCallback | None,
    chunk_size: int = 1048576,
) -> None:
    """Parse a request body in a single call.

    Given a dictionary-like object of the request's headers and a file-like
    object for the input stream, this builds the appropriate parser and
    streams the body through it, invoking the supplied callbacks as each
    field or file is parsed.

    Args:
        headers: A dictionary-like object of HTTP headers. The only required header is Content-Type.
        input_stream: A file-like object that represents the request body. The read() method must return bytestrings.
        on_field: Callback to call with each parsed field.
        on_file: Callback to call with each parsed file.
        chunk_size: The maximum size to read from the input stream and write to the parser at one time.
            Defaults to 1 MiB.
    """
    # Build the parser appropriate for these headers.
    form_parser = create_form_parser(headers, on_field, on_file)

    # Respect the declared Content-Length if one was given; otherwise read
    # until the stream is exhausted.
    raw_length: int | float | bytes | None = headers.get("Content-Length")
    content_length: int | float = float("inf") if raw_length is None else int(raw_length)

    total_read = 0
    while True:
        # Never request more than the remaining declared body length.
        request_size = int(min(content_length - total_read, chunk_size))
        chunk = input_stream.read(request_size)

        # Feed the parser and account for what we actually received.
        form_parser.write(chunk)
        total_read += len(chunk)

        # A short read means end-of-stream; reaching the declared length
        # means the body is complete.
        if len(chunk) != request_size or total_read == content_length:
            break

    # Signal end-of-input to the parser.
    form_parser.finalize()