1from __future__ import annotations
2
3import logging
4import os
5import shutil
6import sys
7import tempfile
8from email.message import Message
9from enum import IntEnum
10from io import BufferedRandom, BytesIO
11from numbers import Number
12from typing import TYPE_CHECKING, cast
13
14from .decoders import Base64Decoder, QuotedPrintableDecoder
15from .exceptions import FileError, FormParserError, MultipartParseError, QuerystringParseError
16
if TYPE_CHECKING: # pragma: no cover
    # Static-typing-only declarations. Nothing in this block exists at
    # runtime; it is only evaluated by type checkers.
    from collections.abc import Callable
    from typing import Any, Literal, Protocol, TypeAlias, TypedDict

    class SupportsRead(Protocol):
        """Structural type for any object with a file-like ``read(n) -> bytes``."""

        def read(self, __n: int) -> bytes: ...

    class QuerystringCallbacks(TypedDict, total=False):
        """Callbacks accepted by QuerystringParser.

        Notification callbacks take no arguments; data callbacks take
        (data, start, end), where data[start:end] is the region of interest.
        """

        on_field_start: Callable[[], None]
        on_field_name: Callable[[bytes, int, int], None]
        on_field_data: Callable[[bytes, int, int], None]
        on_field_end: Callable[[], None]
        on_end: Callable[[], None]

    class OctetStreamCallbacks(TypedDict, total=False):
        """Callbacks accepted by OctetStreamParser."""

        on_start: Callable[[], None]
        on_data: Callable[[bytes, int, int], None]
        on_end: Callable[[], None]

    class MultipartCallbacks(TypedDict, total=False):
        """Callbacks accepted by MultipartParser."""

        on_part_begin: Callable[[], None]
        on_part_data: Callable[[bytes, int, int], None]
        on_part_end: Callable[[], None]
        on_header_begin: Callable[[], None]
        on_header_field: Callable[[bytes, int, int], None]
        on_header_value: Callable[[bytes, int, int], None]
        on_header_end: Callable[[], None]
        on_headers_finished: Callable[[], None]
        on_end: Callable[[], None]

    class FormParserConfig(TypedDict):
        """Form-parsing configuration; all keys are required."""

        UPLOAD_DIR: str | None
        UPLOAD_KEEP_FILENAME: bool
        UPLOAD_KEEP_EXTENSIONS: bool
        UPLOAD_ERROR_ON_BAD_CTE: bool
        MAX_MEMORY_FILE_SIZE: int
        MAX_BODY_SIZE: float

    class FileConfig(TypedDict, total=False):
        """Configuration accepted by File; all keys are optional."""

        UPLOAD_DIR: str | bytes | None
        UPLOAD_DELETE_TMP: bool
        UPLOAD_KEEP_FILENAME: bool
        UPLOAD_KEEP_EXTENSIONS: bool
        MAX_MEMORY_FILE_SIZE: int

    class _FormProtocol(Protocol):
        """Write/finalize/close interface shared by Field-like and File-like sinks."""

        def write(self, data: bytes) -> int: ...

        def finalize(self) -> None: ...

        def close(self) -> None: ...

    class FieldProtocol(_FormProtocol, Protocol):
        """Structural type for Field-like classes."""

        def __init__(self, name: bytes | None) -> None: ...

        def set_none(self) -> None: ...

    class FileProtocol(_FormProtocol, Protocol):
        """Structural type for File-like classes."""

        def __init__(self, file_name: bytes | None, field_name: bytes | None, config: FileConfig) -> None: ...

    # Signatures of user-supplied callbacks that receive a finished
    # Field-like / File-like object.
    OnFieldCallback = Callable[[FieldProtocol], None]
    OnFileCallback = Callable[[FileProtocol], None]

    # Every name accepted by BaseParser.callback()/set_callback(); the
    # "on_" prefix is added internally.
    CallbackName: TypeAlias = Literal[
        "start",
        "data",
        "end",
        "field_start",
        "field_name",
        "field_data",
        "field_end",
        "part_begin",
        "part_data",
        "part_end",
        "header_begin",
        "header_field",
        "header_value",
        "header_end",
        "headers_finished",
    ]
97
# Unique sentinel object: distinguishes "no cached value yet" from a real
# cached value such as None or b"" (see Field._cache).
_missing = object()
100
101
class QuerystringState(IntEnum):
    """Querystring parser states.

    These are used to keep track of the state of the parser, and are used to determine
    what to do when new data is encountered.
    """

    BEFORE_FIELD = 0  # Waiting for a field to start (skipping "&"/";" separators).
    FIELD_NAME = 1  # Reading the field name (before any "=").
    FIELD_DATA = 2  # Reading the field value (after the "=").
112
113
class MultipartState(IntEnum):
    """Multipart parser states.

    These are used to keep track of the state of the parser, and are used to determine
    what to do when new data is encountered.
    """

    # NOTE: the glosses below are derived from the state names; the state
    # machine itself is implemented in MultipartParser.
    START = 0  # Initial state, before any data has been seen.
    START_BOUNDARY = 1  # Consuming the initial boundary line.
    HEADER_FIELD_START = 2  # At the start of a header line.
    HEADER_FIELD = 3  # Inside a header field name.
    HEADER_VALUE_START = 4  # Between the ":" and the header value.
    HEADER_VALUE = 5  # Inside a header value.
    HEADER_VALUE_ALMOST_DONE = 6  # Header-value line terminator partially seen.
    HEADERS_ALMOST_DONE = 7  # Blank line ending the header block partially seen.
    PART_DATA_START = 8  # At the start of a part's payload.
    PART_DATA = 9  # Inside a part's payload.
    PART_DATA_END = 10  # At the end of a part's payload.
    END_BOUNDARY = 11  # Consuming the final boundary.
    END = 12  # All parsing complete.
134
135
# Flags for the multipart parser.
FLAG_PART_BOUNDARY = 1  # The boundary just matched separates two parts.
FLAG_LAST_BOUNDARY = 2  # The boundary just matched is the final ("--") boundary.

# Byte-value constants. Indexing a bytes object yields an int, so b"\r"[0]
# is the integer ordinal that parser code compares incoming bytes against.
# (The original comment referenced Python 2 str-iteration semantics, which no
# longer apply.)
CR = b"\r"[0]
LF = b"\n"[0]
COLON = b":"[0]
SPACE = b" "[0]
HYPHEN = b"-"[0]
AMPERSAND = b"&"[0]
SEMICOLON = b";"[0]
LOWER_A = b"a"[0]
LOWER_Z = b"z"[0]
NULL = b"\x00"[0]

# fmt: off
# Set of ASCII byte values that are valid HTTP token characters.
# Per RFC7230 - 3.2.6, this is all alpha-numeric characters
# and these: !#$%&'*+-.^_`|~
TOKEN_CHARS_SET = frozenset(
    b"ABCDEFGHIJKLMNOPQRSTUVWXYZ"
    b"abcdefghijklmnopqrstuvwxyz"
    b"0123456789"
    b"!#$%&'*+-.^_`|~")
# fmt: on
164
165
def parse_options_header(value: str | bytes | None) -> tuple[bytes, dict[bytes, bytes]]:
    """Parse a Content-Type style header into ``(content_type, {parameters})``.

    Parsing is delegated to :class:`email.message.Message`, the replacement
    for ``cgi.parse_header`` recommended by PEP 594.
    Ref: https://peps.python.org/pep-0594/#cgi
    """
    if not value:
        # Nothing to parse.
        return b"", {}

    # Bytes input is assumed to conform to WSGI, i.e. latin-1 encoded text.
    if isinstance(value, bytes):  # pragma: no cover
        value = value.decode("latin-1")

    # For types
    assert isinstance(value, str), "Value should be a string by now"

    # Fast path: a bare content type with no parameters.
    if ";" not in value:
        return value.lower().strip().encode("latin-1"), {}

    # Let the email machinery split the value from its parameters.
    message = Message()
    message["content-type"] = value
    parsed = message.get_params()
    # A parameter-less header would already have returned above.
    assert parsed, "At least the content type value should be present"

    content_type = parsed[0][0].encode("latin-1")
    options: dict[bytes, bytes] = {}
    for key, param_value in parsed[1:]:
        # get_params() may hand back an RFC 2231 three-tuple; its last
        # element is the actual value.
        # See: https://docs.python.org/3/library/email.compat32-message.html
        if isinstance(param_value, tuple):
            param_value = param_value[-1]
        # Work around the IE6 bug of sending the full client-side path for
        # file uploads: keep only the final path component.
        if key == "filename" and (param_value[1:3] == ":\\" or param_value[:2] == "\\\\"):
            param_value = param_value.rsplit("\\", 1)[-1]
        options[key.encode("latin-1")] = param_value.encode("latin-1")
    return content_type, options
207
208
class Field:
    """A single (parsed) form field: a name together with an accumulated value.

    The name a :class:`Field` is created with matches the ``name`` attribute
    of the corresponding HTML input element::

        <input name="name_goes_here" type="text"/>

    Incoming data is delivered through :meth:`on_data`, and :meth:`on_end` is
    invoked once the field is complete.

    Args:
        name: The name of the form field.
    """

    def __init__(self, name: bytes | None) -> None:
        self._name = name
        self._value: list[bytes] = []

        # The joined value is computed lazily; _missing marks "not cached yet".
        self._cache = _missing

    @classmethod
    def from_value(cls, name: bytes, value: bytes | None) -> Field:
        """Build and finalize a :class:`Field` in one step, setting its value
        to either None or the given bytestring.

        Args:
            name: the name of the form field.
            value: the value of the form field - either a bytestring or None.

        Returns:
            A new, already-finalized [`Field`][python_multipart.Field].
        """
        field = cls(name)
        if value is None:
            field.set_none()
        else:
            field.write(value)
        field.finalize()
        return field

    def write(self, data: bytes) -> int:
        """Append some data to the form field.

        Args:
            data: The data to write to the field.

        Returns:
            The number of bytes written.
        """
        return self.on_data(data)

    def on_data(self, data: bytes) -> int:
        """Callback invoked whenever data is written to this Field.

        Args:
            data: The data to write to the field.

        Returns:
            The number of bytes written.
        """
        self._value.append(data)
        # Any previously joined value is now stale.
        self._cache = _missing
        return len(data)

    def on_end(self) -> None:
        """Callback invoked when the Field is finalized."""
        if self._cache is _missing:
            self._cache = b"".join(self._value)

    def finalize(self) -> None:
        """Finalize the form field."""
        self.on_end()

    def close(self) -> None:
        """Close the Field object, freeing the underlying chunk list."""
        if self._cache is _missing:
            self._cache = b"".join(self._value)

        # Drop the chunk list; the joined value lives on in _cache.
        del self._value

    def set_none(self) -> None:
        """Mark this field as having the value None.

        A querystring such as "foo&bar=&baz=asdf" yields a field "foo" with
        value None, "bar" with value b"", and "baz" with value b"asdf".
        Since write() cannot express None, this method sets it directly.
        """
        self._cache = None

    @property
    def field_name(self) -> bytes | None:
        """The name of this field."""
        return self._name

    @property
    def value(self) -> bytes | None:
        """The value of the form field (joined lazily and cached)."""
        if self._cache is _missing:
            self._cache = b"".join(self._value)

        assert self._cache is None or isinstance(self._cache, bytes)
        return self._cache

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, Field):
            return NotImplemented
        return self.field_name == other.field_name and self.value == other.value

    def __repr__(self) -> str:
        val = self.value
        if val is None or len(val) <= 97:
            shown = repr(val)
        else:
            # Truncate long values, splicing "..." in before the closing quote.
            shown = repr(val[:97])[:-1] + "...'"

        return f"{self.__class__.__name__}(field_name={self.field_name!r}, value={shown})"
335
336
class File:
    """This class represents an uploaded file. It handles writing file data to
    either an in-memory file or a temporary file on-disk, if the optional
    threshold is passed.

    There are some options that can be passed to the File to change behavior
    of the class. Valid options are as follows:

    | Name | Type | Default | Description |
    |-----------------------|-------|---------|-------------|
    | UPLOAD_DIR | `str` | None | The directory to store uploaded files in. If this is None, a temporary file will be created in the system's standard location. |
    | UPLOAD_DELETE_TMP | `bool`| True | Delete automatically created TMP file |
    | UPLOAD_KEEP_FILENAME | `bool`| False | Whether or not to keep the filename of the uploaded file. If True, then the filename will be converted to a safe representation (e.g. by removing any invalid path segments), and then saved with the same name). Otherwise, a temporary name will be used. |
    | UPLOAD_KEEP_EXTENSIONS| `bool`| False | Whether or not to keep the uploaded file's extension. If False, the file will be saved with the default temporary extension (usually ".tmp"). Otherwise, the file's extension will be maintained. Note that this will properly combine with the UPLOAD_KEEP_FILENAME setting. |
    | MAX_MEMORY_FILE_SIZE | `int` | 1 MiB | The maximum number of bytes of a File to keep in memory. By default, the contents of a File are kept into memory until a certain limit is reached, after which the contents of the File are written to a temporary file. This behavior can be disabled by setting this value to an appropriately large value (or, for example, infinity, such as `float('inf')`. |

    Args:
        file_name: The name of the file that this [`File`][python_multipart.File] represents.
        field_name: The name of the form field that this file was uploaded with. This can be None, if, for example,
            the file was uploaded with Content-Type application/octet-stream.
        config: The configuration for this File. See above for valid configuration keys and their corresponding values.
            Defaults to None, which is treated as an empty configuration.
    """  # noqa: E501

    def __init__(
        self, file_name: bytes | None, field_name: bytes | None = None, config: FileConfig | None = None
    ) -> None:
        # Save configuration, set other variables default.
        # NOTE: `config` previously defaulted to a shared mutable `{}`; use a
        # None sentinel and a fresh dict per instance so instances can never
        # accidentally alias one another's configuration.
        self.logger = logging.getLogger(__name__)
        self._config: FileConfig = config if config is not None else {}
        self._in_memory = True
        self._bytes_written = 0
        self._fileobj: BytesIO | BufferedRandom = BytesIO()

        # Save the provided field/file name.
        self._field_name = field_name
        self._file_name = file_name

        # Our actual file name is None by default, since, depending on our
        # config, we may not actually use the provided name.
        self._actual_file_name: bytes | None = None

        # Split the extension from the filename. When no filename was given,
        # fall back to an empty base/extension so that _get_disk_file() can
        # still run (it previously raised AttributeError in that case).
        self._file_base: bytes = b""
        self._ext: bytes = b""
        if file_name is not None:
            # Extract just the basename to avoid directory traversal
            basename = os.path.basename(file_name)
            self._file_base, self._ext = os.path.splitext(basename)

    @property
    def field_name(self) -> bytes | None:
        """The form field associated with this file. May be None if there isn't
        one, for example when we have an application/octet-stream upload.
        """
        return self._field_name

    @property
    def file_name(self) -> bytes | None:
        """The file name given in the upload request."""
        return self._file_name

    @property
    def actual_file_name(self) -> bytes | None:
        """The file name that this file is saved as. Will be None if it's not
        currently saved on disk.
        """
        return self._actual_file_name

    @property
    def file_object(self) -> BytesIO | BufferedRandom:
        """The file object that we're currently writing to. Note that this
        will either be an instance of a :class:`io.BytesIO`, or a regular file
        object.
        """
        return self._fileobj

    @property
    def size(self) -> int:
        """The total size of this file, counted as the number of bytes that
        currently have been written to the file.
        """
        return self._bytes_written

    @property
    def in_memory(self) -> bool:
        """A boolean representing whether or not this file object is currently
        stored in-memory or on-disk.
        """
        return self._in_memory

    def flush_to_disk(self) -> None:
        """If the file is already on-disk, do nothing. Otherwise, copy from
        the in-memory buffer to a disk file, and then reassign our internal
        file object to this new disk file.

        Note that if you attempt to flush a file that is already on-disk, a
        warning will be logged to this module's logger.
        """
        if not self._in_memory:
            self.logger.warning("Trying to flush to disk when we're not in memory")
            return

        # Go back to the start of our file.
        self._fileobj.seek(0)

        # Open a new file.
        new_file = self._get_disk_file()

        # Copy the file objects.
        shutil.copyfileobj(self._fileobj, new_file)

        # Seek to the new position in our new file.
        new_file.seek(self._bytes_written)

        # Reassign the fileobject.
        old_fileobj = self._fileobj
        self._fileobj = new_file

        # We're no longer in memory.
        self._in_memory = False

        # Close the old file object.
        old_fileobj.close()

    def _get_disk_file(self) -> BufferedRandom:
        """Open and return an on-disk file object, honoring the UPLOAD_*
        configuration keys. Sets self._actual_file_name as a side effect.

        Raises:
            FileError: if the file (or named temporary file) cannot be opened.
        """
        self.logger.info("Opening a file on disk")

        file_dir = self._config.get("UPLOAD_DIR")
        keep_filename = self._config.get("UPLOAD_KEEP_FILENAME", False)
        keep_extensions = self._config.get("UPLOAD_KEEP_EXTENSIONS", False)
        delete_tmp = self._config.get("UPLOAD_DELETE_TMP", True)
        tmp_file: None | BufferedRandom = None

        # If we have a directory and are to keep the filename...
        if file_dir is not None and keep_filename:
            self.logger.info("Saving with filename in: %r", file_dir)

            # Build our filename. With no original filename this degrades to
            # an empty basename (and an OSError -> FileError on open below).
            fname = self._file_base + self._ext if keep_extensions else self._file_base

            path = os.path.join(file_dir, fname)  # type: ignore[arg-type]
            try:
                self.logger.info("Opening file: %r", path)
                tmp_file = open(path, "w+b")
            except OSError as exc:
                # Chain the original error so callers can see the root cause.
                self.logger.exception("Error opening temporary file")
                raise FileError("Error opening temporary file: %r" % path) from exc
        else:
            # Build options array.
            # Note that on Python 3, tempfile doesn't support byte names. We
            # encode our paths using the default filesystem encoding.
            suffix = self._ext.decode(sys.getfilesystemencoding()) if keep_extensions else None

            if file_dir is None:
                dir = None
            elif isinstance(file_dir, bytes):
                dir = file_dir.decode(sys.getfilesystemencoding())
            else:
                dir = file_dir  # pragma: no cover

            # Create a temporary (named) file with the appropriate settings.
            self.logger.info(
                "Creating a temporary file with options: %r", {"suffix": suffix, "delete": delete_tmp, "dir": dir}
            )
            try:
                tmp_file = cast(BufferedRandom, tempfile.NamedTemporaryFile(suffix=suffix, delete=delete_tmp, dir=dir))
            except OSError as exc:
                self.logger.exception("Error creating named temporary file")
                raise FileError("Error creating named temporary file") from exc

        assert tmp_file is not None
        # Encode filename as bytes.
        if isinstance(tmp_file.name, str):
            fname = tmp_file.name.encode(sys.getfilesystemencoding())
        else:
            fname = cast(bytes, tmp_file.name)  # pragma: no cover

        self._actual_file_name = fname
        return tmp_file

    def write(self, data: bytes) -> int:
        """Write some data to the File.

        :param data: a bytestring
        """
        return self.on_data(data)

    def on_data(self, data: bytes) -> int:
        """This method is a callback that will be called whenever data is
        written to the File.

        Args:
            data: The data to write to the file.

        Returns:
            The number of bytes written.
        """
        bwritten = self._fileobj.write(data)

        # If the bytes written isn't the same as the length, just return.
        if bwritten != len(data):
            self.logger.warning("bwritten != len(data) (%d != %d)", bwritten, len(data))
            return bwritten

        # Keep track of how many bytes we've written.
        self._bytes_written += bwritten

        # If we're in-memory and are over our limit, we create a file.
        max_memory_file_size = self._config.get("MAX_MEMORY_FILE_SIZE")
        if self._in_memory and max_memory_file_size is not None and (self._bytes_written > max_memory_file_size):
            self.logger.info("Flushing to disk")
            self.flush_to_disk()

        # Return the number of bytes written.
        return bwritten

    def on_end(self) -> None:
        """This method is called whenever the File is finalized."""
        # Flush the underlying file object
        self._fileobj.flush()

    def finalize(self) -> None:
        """Finalize the form file. This will not close the underlying file,
        but simply signal that we are finished writing to the File.
        """
        self.on_end()

    def close(self) -> None:
        """Close the File object. This will actually close the underlying
        file object (whether it's a :class:`io.BytesIO` or an actual file
        object).
        """
        self._fileobj.close()

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(file_name={self.file_name!r}, field_name={self.field_name!r})"
575
576
class BaseParser:
    """Common machinery shared by all parsers: registering callbacks and
    dispatching events to them.

    Two callback shapes exist. "Notification callbacks" fire when something
    happens - for example, when a new part of a multipart message is
    encountered - and receive no arguments. "Data callbacks" fire when a
    chunk of payload is available - for example, part of a multipart body -
    and receive three arguments::

        data_callback(data, start, end)

    where ``data`` is a bytestring and ``start``/``end`` are integer indexes
    into it: the slice ``data[start:end]`` is the region the callback should
    look at. The data is deliberately not copied before the call, since
    copying would severely hurt performance.
    """

    def __init__(self) -> None:
        self.logger = logging.getLogger(__name__)
        self.callbacks: QuerystringCallbacks | OctetStreamCallbacks | MultipartCallbacks = {}

    def callback(
        self, name: CallbackName, data: bytes | None = None, start: int | None = None, end: int | None = None
    ) -> None:
        """Invoke the callback registered under ``name``, if any; otherwise
        do nothing.

        Args:
            name: The name of the callback to call (as a string).
            data: Payload for a data callback. None means this is a
                notification callback, invoked with no arguments.
            start: Start index passed to a data callback.
            end: End index passed to a data callback.
        """
        key = "on_" + name
        handler = self.callbacks.get(key)
        if handler is None:
            return
        handler = cast("Callable[..., Any]", handler)

        if data is None:
            # Notification callback: no arguments at all.
            self.logger.debug("Calling %s with no data", key)
            handler()
            return

        # Data callback: an empty slice is not worth a call.
        if start is not None and start == end:
            return
        self.logger.debug("Calling %s with data[%d:%d]", key, start, end)
        handler(data, start, end)

    def set_callback(self, name: CallbackName, new_func: Callable[..., Any] | None) -> None:
        """Register, replace, or remove a callback.

        :param name: The name of the callback to call (as a string).

        :param new_func: The new function for the callback. If None, the
                         callback is removed (silently, even if it was never
                         registered).
        """
        key = "on_" + name
        if new_func is None:
            self.callbacks.pop(key, None)  # type: ignore[misc]
        else:
            self.callbacks[key] = new_func  # type: ignore[literal-required]

    def close(self) -> None:
        pass  # pragma: no cover

    def finalize(self) -> None:
        pass  # pragma: no cover

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}()"
655
656
class OctetStreamParser(BaseParser):
    """This parser parses an octet-stream request body and calls callbacks when
    incoming data is received. Callbacks are as follows:

    | Callback Name | Parameters | Description |
    |----------------|-----------------|-----------------------------------------------------|
    | on_start | None | Called when the first data is parsed. |
    | on_data | data, start, end| Called for each data chunk that is parsed. |
    | on_end | None | Called when the parser is finished parsing all data.|

    Args:
        callbacks: A dictionary of callbacks. See the documentation for [`BaseParser`][python_multipart.BaseParser].
            Defaults to None, which is treated as an empty callback dictionary.
        max_size: The maximum size of body to parse. Defaults to infinity - i.e. unbounded.
    """

    def __init__(self, callbacks: OctetStreamCallbacks | None = None, max_size: float = float("inf")):
        super().__init__()
        # Use a fresh dict when no callbacks are given: the old shared
        # mutable default ({}) was mutated by set_callback(), leaking
        # handlers across every parser instance created without callbacks.
        self.callbacks = callbacks if callbacks is not None else {}
        self._started = False

        # Reject non-numeric or non-positive limits up front.
        if not isinstance(max_size, Number) or max_size < 1:
            raise ValueError("max_size must be a positive number, not %r" % max_size)
        self.max_size: int | float = max_size
        self._current_size = 0

    def write(self, data: bytes) -> int:
        """Write some data to the parser, which will perform size verification,
        and then pass the data to the underlying callback.

        Args:
            data: The data to write to the parser.

        Returns:
            The number of bytes written (possibly truncated to max_size).
        """
        # Emit on_start exactly once, before the first chunk.
        if not self._started:
            self.callback("start")
            self._started = True

        # Truncate data length so we never process more than max_size bytes.
        data_len = len(data)
        if (self._current_size + data_len) > self.max_size:
            # We truncate the length of data that we are to process.
            new_size = int(self.max_size - self._current_size)
            self.logger.warning(
                "Current size is %d (max %d), so truncating data length from %d to %d",
                self._current_size,
                self.max_size,
                data_len,
                new_size,
            )
            data_len = new_size

        # Increment size, then callback, in case there's an exception.
        self._current_size += data_len
        self.callback("data", data, 0, data_len)
        return data_len

    def finalize(self) -> None:
        """Finalize this parser, which signals to that we are finished parsing,
        and sends the on_end callback.
        """
        self.callback("end")

    def __repr__(self) -> str:
        return "%s()" % self.__class__.__name__
723
724
class QuerystringParser(BaseParser):
    """This is a streaming querystring parser. It will consume data, and call
    the callbacks given when it has data.

    | Callback Name | Parameters | Description |
    |----------------|-----------------|-----------------------------------------------------|
    | on_field_start | None | Called when a new field is encountered. |
    | on_field_name | data, start, end| Called when a portion of a field's name is encountered. |
    | on_field_data | data, start, end| Called when a portion of a field's data is encountered. |
    | on_field_end | None | Called when the end of a field is encountered. |
    | on_end | None | Called when the parser is finished parsing all data.|

    Args:
        callbacks: A dictionary of callbacks. See the documentation for [`BaseParser`][python_multipart.BaseParser].
            Defaults to None, which is treated as an empty callback dictionary.
        strict_parsing: Whether or not to parse the body strictly. Defaults to False. If this is set to True, then the
            behavior of the parser changes as the following: if a field has a value with an equal sign
            (e.g. "foo=bar", or "foo="), it is always included. If a field has no equals sign (e.g. "...&name&..."),
            it will be treated as an error if 'strict_parsing' is True, otherwise included. If an error is encountered,
            then a [`QuerystringParseError`][python_multipart.exceptions.QuerystringParseError] will be raised.
        max_size: The maximum size of body to parse. Defaults to infinity - i.e. unbounded.
    """  # noqa: E501

    state: QuerystringState

    def __init__(
        self, callbacks: QuerystringCallbacks | None = None, strict_parsing: bool = False, max_size: float = float("inf")
    ) -> None:
        super().__init__()
        self.state = QuerystringState.BEFORE_FIELD
        self._found_sep = False

        # Use a fresh dict when no callbacks are given: the old shared
        # mutable default ({}) was mutated by set_callback(), leaking
        # handlers across every parser instance created without callbacks.
        self.callbacks = callbacks if callbacks is not None else {}

        # Max-size stuff: reject non-numeric or non-positive limits up front.
        if not isinstance(max_size, Number) or max_size < 1:
            raise ValueError("max_size must be a positive number, not %r" % max_size)
        self.max_size: int | float = max_size
        self._current_size = 0

        # Should parsing be strict?
        self.strict_parsing = strict_parsing

    def write(self, data: bytes) -> int:
        """Write some data to the parser, which will perform size verification,
        parse into either a field name or value, and then pass the
        corresponding data to the underlying callback. If an error is
        encountered while parsing, a QuerystringParseError will be raised. The
        "offset" attribute of the raised exception will be set to the offset in
        the input data chunk (NOT the overall stream) that caused the error.

        Args:
            data: The data to write to the parser.

        Returns:
            The number of bytes written.
        """
        # Handle sizing.
        data_len = len(data)
        if (self._current_size + data_len) > self.max_size:
            # We truncate the length of data that we are to process.
            new_size = int(self.max_size - self._current_size)
            self.logger.warning(
                "Current size is %d (max %d), so truncating data length from %d to %d",
                self._current_size,
                self.max_size,
                data_len,
                new_size,
            )
            data_len = new_size

        written = 0
        try:
            written = self._internal_write(data, data_len)
        finally:
            # Account for consumed bytes even if parsing raised mid-chunk.
            self._current_size += written

        return written

    def _internal_write(self, data: bytes, length: int) -> int:
        """Run the state machine over data[:length], firing callbacks.

        Raises:
            QuerystringParseError: on malformed input when strict_parsing is
                True (the exception's "offset" is the index in this chunk).
        """
        state = self.state
        strict_parsing = self.strict_parsing
        found_sep = self._found_sep

        i = 0
        while i < length:
            ch = data[i]

            # Depending on our state...
            if state == QuerystringState.BEFORE_FIELD:
                # If the 'found_sep' flag is set, we've already encountered
                # and skipped a single separator. If so, we check our strict
                # parsing flag and decide what to do. Otherwise, we haven't
                # yet reached a separator, and thus, if we do, we need to skip
                # it as it will be the boundary between fields that's supposed
                # to be there.
                if ch == AMPERSAND or ch == SEMICOLON:
                    if found_sep:
                        # If we're parsing strictly, we disallow blank chunks.
                        if strict_parsing:
                            e = QuerystringParseError("Skipping duplicate ampersand/semicolon at %d" % i)
                            e.offset = i
                            raise e
                        else:
                            self.logger.debug("Skipping duplicate ampersand/semicolon at %d", i)
                    else:
                        # This case is when we're skipping the (first)
                        # separator between fields, so we just set our flag
                        # and continue on.
                        found_sep = True
                else:
                    # Emit a field-start event, and go to that state. Also,
                    # reset the "found_sep" flag, for the next time we get to
                    # this state.
                    self.callback("field_start")
                    i -= 1
                    state = QuerystringState.FIELD_NAME
                    found_sep = False

            elif state == QuerystringState.FIELD_NAME:
                # Try and find a separator - we ensure that, if we do, we only
                # look for the equal sign before it.
                sep_pos = data.find(b"&", i)
                if sep_pos == -1:
                    sep_pos = data.find(b";", i)

                # See if we can find an equals sign in the remaining data. If
                # so, we can immediately emit the field name and jump to the
                # data state.
                if sep_pos != -1:
                    equals_pos = data.find(b"=", i, sep_pos)
                else:
                    equals_pos = data.find(b"=", i)

                if equals_pos != -1:
                    # Emit this name.
                    self.callback("field_name", data, i, equals_pos)

                    # Jump i to this position. Note that it will then have 1
                    # added to it below, which means the next iteration of this
                    # loop will inspect the character after the equals sign.
                    i = equals_pos
                    state = QuerystringState.FIELD_DATA
                else:
                    # No equals sign found.
                    if not strict_parsing:
                        # See also comments in the QuerystringState.FIELD_DATA case below.
                        # If we found the separator, we emit the name and just
                        # end - there's no data callback at all (not even with
                        # a blank value).
                        if sep_pos != -1:
                            self.callback("field_name", data, i, sep_pos)
                            self.callback("field_end")

                            i = sep_pos - 1
                            state = QuerystringState.BEFORE_FIELD
                        else:
                            # Otherwise, no separator in this block, so the
                            # rest of this chunk must be a name.
                            self.callback("field_name", data, i, length)
                            i = length

                    else:
                        # We're parsing strictly. If we find a separator,
                        # this is an error - we require an equals sign.
                        if sep_pos != -1:
                            e = QuerystringParseError(
                                "When strict_parsing is True, we require an "
                                "equals sign in all field chunks. Did not "
                                "find one in the chunk that starts at %d" % (i,)
                            )
                            e.offset = i
                            raise e

                        # No separator in the rest of this chunk, so it's just
                        # a field name.
                        self.callback("field_name", data, i, length)
                        i = length

            elif state == QuerystringState.FIELD_DATA:
                # Try finding either an ampersand or a semicolon after this
                # position.
                sep_pos = data.find(b"&", i)
                if sep_pos == -1:
                    sep_pos = data.find(b";", i)

                # If we found it, callback this bit as data and then go back
                # to expecting to find a field.
                if sep_pos != -1:
                    self.callback("field_data", data, i, sep_pos)
                    self.callback("field_end")

                    # Note that we go to the separator, which brings us to the
                    # "before field" state. This allows us to properly emit
                    # "field_start" events only when we actually have data for
                    # a field of some sort.
                    i = sep_pos - 1
                    state = QuerystringState.BEFORE_FIELD

                # Otherwise, emit the rest as data and finish.
                else:
                    self.callback("field_data", data, i, length)
                    i = length

            else:  # pragma: no cover (error case)
                msg = "Reached an unknown state %d at %d" % (state, i)
                self.logger.warning(msg)
                e = QuerystringParseError(msg)
                e.offset = i
                raise e

            i += 1

        self.state = state
        self._found_sep = found_sep
        # NOTE(review): this reports len(data) even when write() truncated
        # the processed length to max_size; confirm whether callers depend on
        # the full chunk length being returned here.
        return len(data)

    def finalize(self) -> None:
        """Finalize this parser, which signals to that we are finished parsing,
        if we're still in the middle of a field, an on_field_end callback, and
        then the on_end callback.
        """
        # If we're currently in the middle of a field, we finish it.
        if self.state == QuerystringState.FIELD_DATA:
            self.callback("field_end")
        self.callback("end")

    def __repr__(self) -> str:
        return "{}(strict_parsing={!r}, max_size={!r})".format(
            self.__class__.__name__, self.strict_parsing, self.max_size
        )
955
956
class MultipartParser(BaseParser):
    """This class is a streaming multipart/form-data parser.

    | Callback Name | Parameters | Description |
    |--------------------|-----------------|-------------|
    | on_part_begin | None | Called when a new part of the multipart message is encountered. |
    | on_part_data | data, start, end| Called when a portion of a part's data is encountered. |
    | on_part_end | None | Called when the end of a part is reached. |
    | on_header_begin | None | Called when we've found a new header in a part of a multipart message |
    | on_header_field | data, start, end| Called each time an additional portion of a header is read (i.e. the part of the header that is before the colon; the "Foo" in "Foo: Bar"). |
    | on_header_value | data, start, end| Called when we get data for a header. |
    | on_header_end | None | Called when the current header is finished - i.e. we've reached the newline at the end of the header. |
    | on_headers_finished| None | Called when all headers are finished, and before the part data starts. |
    | on_end | None | Called when the parser is finished parsing all data. |

    Args:
        boundary: The multipart boundary. This is required, and must match what is given in the HTTP request - usually in the Content-Type header.
        callbacks: A dictionary of callbacks. See the documentation for [`BaseParser`][python_multipart.BaseParser].
        max_size: The maximum size of body to parse. Defaults to infinity - i.e. unbounded.
    """  # noqa: E501

    def __init__(
        self, boundary: bytes | str, callbacks: MultipartCallbacks = {}, max_size: float = float("inf")
    ) -> None:
        """Initialize the parser state, validate max_size, and store the
        boundary (prefixed with CRLF + "--" so it can be searched for as a
        single byte string in the incoming stream).

        Raises:
            ValueError: If max_size is not a positive number.
        """
        # Initialize parser state.
        super().__init__()
        self.state = MultipartState.START
        self.index = self.flags = 0

        # NOTE: the default `callbacks` dict is shared between calls but is
        # only ever read here, never mutated, so the mutable default is safe.
        self.callbacks = callbacks

        if not isinstance(max_size, Number) or max_size < 1:
            raise ValueError("max_size must be a positive number, not %r" % max_size)
        self.max_size = max_size
        self._current_size = 0

        # Setup marks. These are used to track the state of data received.
        # A negative mark means the marked data begins inside the look-behind
        # buffer (a partial boundary match carried over from a prior chunk).
        self.marks: dict[str, int] = {}

        # Save our boundary.
        if isinstance(boundary, str): # pragma: no cover
            boundary = boundary.encode("latin-1")
        self.boundary = b"\r\n--" + boundary

    def write(self, data: bytes) -> int:
        """Write some data to the parser, which will perform size verification,
        and then parse the data into the appropriate location (e.g. header,
        data, etc.), and pass this on to the underlying callback. If an error
        is encountered, a MultipartParseError will be raised. The "offset"
        attribute on the raised exception will be set to the offset of the byte
        in the input chunk that caused the error.

        Args:
            data: The data to write to the parser.

        Returns:
            The number of bytes written.
        """
        # Handle sizing.
        data_len = len(data)
        if (self._current_size + data_len) > self.max_size:
            # We truncate the length of data that we are to process.
            new_size = int(self.max_size - self._current_size)
            self.logger.warning(
                "Current size is %d (max %d), so truncating data length from %d to %d",
                self._current_size,
                self.max_size,
                data_len,
                new_size,
            )
            data_len = new_size

        # Track how many bytes were actually consumed, even if the internal
        # write raises part-way through.
        l = 0
        try:
            l = self._internal_write(data, data_len)
        finally:
            self._current_size += l

        return l

    def _internal_write(self, data: bytes, length: int) -> int:
        """Run the multipart state machine over ``data[:length]``.

        ``self.state``, ``self.index`` and ``self.flags`` persist between
        calls so that a boundary split across two chunks is still matched.
        Returns the number of bytes processed; raises MultipartParseError
        (with ``offset`` set) on malformed input.
        """
        # Get values from locals.
        boundary = self.boundary

        # Get our state, flags and index. These are persisted between calls to
        # this function.
        state = self.state
        index = self.index
        flags = self.flags

        # Our index defaults to 0.
        i = 0

        # Set a mark.
        def set_mark(name: str) -> None:
            self.marks[name] = i

        # Remove a mark.
        # NOTE(review): the `reset` parameter is currently unused.
        def delete_mark(name: str, reset: bool = False) -> None:
            self.marks.pop(name, None)

        # Helper function that makes calling a callback with data easier. The
        # 'remaining' parameter will callback from the marked value until the
        # end of the buffer, and reset the mark, instead of deleting it. This
        # is used at the end of the function to call our callbacks with any
        # remaining data in this chunk.
        def data_callback(name: CallbackName, end_i: int, remaining: bool = False) -> None:
            marked_index = self.marks.get(name)
            if marked_index is None:
                return

            # Otherwise, we call it from the mark to the current byte we're
            # processing.
            if end_i <= marked_index:
                # There is no additional data to send.
                pass
            elif marked_index >= 0:
                # We are emitting data from the local buffer.
                self.callback(name, data, marked_index, end_i)
            else:
                # Some of the data comes from a partial boundary match.
                # and requires look-behind.
                # We need to use self.flags (and not flags) because we care about
                # the state when we entered the loop.
                lookbehind_len = -marked_index
                if lookbehind_len <= len(boundary):
                    self.callback(name, boundary, 0, lookbehind_len)
                elif self.flags & FLAG_PART_BOUNDARY:
                    lookback = boundary + b"\r\n"
                    self.callback(name, lookback, 0, lookbehind_len)
                elif self.flags & FLAG_LAST_BOUNDARY:
                    lookback = boundary + b"--\r\n"
                    self.callback(name, lookback, 0, lookbehind_len)
                else: # pragma: no cover (error case)
                    self.logger.warning("Look-back buffer error")

                if end_i > 0:
                    self.callback(name, data, 0, end_i)
            # If we're getting remaining data, we have got all the data we
            # can be certain is not a boundary, leaving only a partial boundary match.
            if remaining:
                # A negative mark records how far into a potential boundary we are.
                self.marks[name] = end_i - length
            else:
                self.marks.pop(name, None)

        # For each byte...
        while i < length:
            c = data[i]

            if state == MultipartState.START:
                # Skip leading newlines
                if c == CR or c == LF:
                    i += 1
                    continue

                # index is used as in index into our boundary. Set to 0.
                index = 0

                # Move to the next state, but decrement i so that we re-process
                # this character.
                state = MultipartState.START_BOUNDARY
                i -= 1

            elif state == MultipartState.START_BOUNDARY:
                # Check to ensure that the last 2 characters in our boundary
                # are CRLF.
                if index == len(boundary) - 2:
                    if c == HYPHEN:
                        # Potential empty message.
                        state = MultipartState.END_BOUNDARY
                    elif c != CR:
                        # Error!
                        msg = "Did not find CR at end of boundary (%d)" % (i,)
                        self.logger.warning(msg)
                        e = MultipartParseError(msg)
                        e.offset = i
                        raise e

                    index += 1

                elif index == len(boundary) - 2 + 1:
                    if c != LF:
                        msg = "Did not find LF at end of boundary (%d)" % (i,)
                        self.logger.warning(msg)
                        e = MultipartParseError(msg)
                        e.offset = i
                        raise e

                    # The index is now used for indexing into our boundary.
                    index = 0

                    # Callback for the start of a part.
                    self.callback("part_begin")

                    # Move to the next character and state.
                    state = MultipartState.HEADER_FIELD_START

                else:
                    # Check to ensure our boundary matches
                    # (offset +2 skips the leading CRLF stored in self.boundary).
                    if c != boundary[index + 2]:
                        msg = "Expected boundary character %r, got %r at index %d" % (boundary[index + 2], c, index + 2)
                        self.logger.warning(msg)
                        e = MultipartParseError(msg)
                        e.offset = i
                        raise e

                    # Increment index into boundary and continue.
                    index += 1

            elif state == MultipartState.HEADER_FIELD_START:
                # Mark the start of a header field here, reset the index, and
                # continue parsing our header field.
                index = 0

                # Set a mark of our header field.
                set_mark("header_field")

                # Notify that we're starting a header if the next character is
                # not a CR; a CR at the beginning of the header will cause us
                # to stop parsing headers in the MultipartState.HEADER_FIELD state,
                # below.
                if c != CR:
                    self.callback("header_begin")

                # Move to parsing header fields.
                state = MultipartState.HEADER_FIELD
                i -= 1

            elif state == MultipartState.HEADER_FIELD:
                # If we've reached a CR at the beginning of a header, it means
                # that we've reached the second of 2 newlines, and so there are
                # no more headers to parse.
                if c == CR and index == 0:
                    delete_mark("header_field")
                    state = MultipartState.HEADERS_ALMOST_DONE
                    i += 1
                    continue

                # Increment our index in the header.
                index += 1

                # If we've reached a colon, we're done with this header.
                if c == COLON:
                    # A 0-length header is an error.
                    if index == 1:
                        msg = "Found 0-length header at %d" % (i,)
                        self.logger.warning(msg)
                        e = MultipartParseError(msg)
                        e.offset = i
                        raise e

                    # Call our callback with the header field.
                    data_callback("header_field", i)

                    # Move to parsing the header value.
                    state = MultipartState.HEADER_VALUE_START

                elif c not in TOKEN_CHARS_SET:
                    msg = "Found invalid character %r in header at %d" % (c, i)
                    self.logger.warning(msg)
                    e = MultipartParseError(msg)
                    e.offset = i
                    raise e

            elif state == MultipartState.HEADER_VALUE_START:
                # Skip leading spaces.
                if c == SPACE:
                    i += 1
                    continue

                # Mark the start of the header value.
                set_mark("header_value")

                # Move to the header-value state, reprocessing this character.
                state = MultipartState.HEADER_VALUE
                i -= 1

            elif state == MultipartState.HEADER_VALUE:
                # If we've got a CR, we're nearly done our headers. Otherwise,
                # we do nothing and just move past this character.
                if c == CR:
                    data_callback("header_value", i)
                    self.callback("header_end")
                    state = MultipartState.HEADER_VALUE_ALMOST_DONE

            elif state == MultipartState.HEADER_VALUE_ALMOST_DONE:
                # The last character should be a LF. If not, it's an error.
                if c != LF:
                    msg = f"Did not find LF character at end of header (found {c!r})"
                    self.logger.warning(msg)
                    e = MultipartParseError(msg)
                    e.offset = i
                    raise e

                # Move back to the start of another header. Note that if that
                # state detects ANOTHER newline, it'll trigger the end of our
                # headers.
                state = MultipartState.HEADER_FIELD_START

            elif state == MultipartState.HEADERS_ALMOST_DONE:
                # We're almost done our headers. This is reached when we parse
                # a CR at the beginning of a header, so our next character
                # should be a LF, or it's an error.
                if c != LF:
                    msg = f"Did not find LF at end of headers (found {c!r})"
                    self.logger.warning(msg)
                    e = MultipartParseError(msg)
                    e.offset = i
                    raise e

                self.callback("headers_finished")
                state = MultipartState.PART_DATA_START

            elif state == MultipartState.PART_DATA_START:
                # Mark the start of our part data.
                set_mark("part_data")

                # Start processing part data, including this character.
                state = MultipartState.PART_DATA
                i -= 1

            elif state == MultipartState.PART_DATA:
                # We're processing our part data right now. During this, we
                # need to efficiently search for our boundary, since any data
                # on any number of lines can be a part of the current data.

                # Save the current value of our index. We use this in case we
                # find part of a boundary, but it doesn't match fully.
                prev_index = index

                # Set up variables.
                boundary_length = len(boundary)
                data_length = length

                # If our index is 0, we're starting a new part, so start our
                # search.
                if index == 0:
                    # The most common case is likely to be that the whole
                    # boundary is present in the buffer.
                    # Calling `find` is much faster than iterating here.
                    i0 = data.find(boundary, i, data_length)
                    if i0 >= 0:
                        # We matched the whole boundary string.
                        index = boundary_length - 1
                        i = i0 + boundary_length - 1
                    else:
                        # No match found for whole string.
                        # There may be a partial boundary at the end of the
                        # data, which the find will not match.
                        # Since the length should to be searched is limited to
                        # the boundary length, just perform a naive search.
                        i = max(i, data_length - boundary_length)

                        # Search forward until we either hit the end of our buffer,
                        # or reach a potential start of the boundary.
                        while i < data_length - 1 and data[i] != boundary[0]:
                            i += 1

                    c = data[i]

                # Now, we have a couple of cases here. If our index is before
                # the end of the boundary...
                if index < boundary_length:
                    # If the character matches...
                    if boundary[index] == c:
                        # The current character matches, so continue!
                        index += 1
                    else:
                        index = 0

                # Our index is equal to the length of our boundary!
                elif index == boundary_length:
                    # First we increment it.
                    index += 1

                    # Now, if we've reached a newline, we need to set this as
                    # the potential end of our boundary.
                    if c == CR:
                        flags |= FLAG_PART_BOUNDARY

                    # Otherwise, if this is a hyphen, we might be at the last
                    # of all boundaries.
                    elif c == HYPHEN:
                        flags |= FLAG_LAST_BOUNDARY

                    # Otherwise, we reset our index, since this isn't either a
                    # newline or a hyphen.
                    else:
                        index = 0

                # Our index is right after the part boundary, which should be
                # a LF.
                elif index == boundary_length + 1:
                    # If we're at a part boundary (i.e. we've seen a CR
                    # character already)...
                    if flags & FLAG_PART_BOUNDARY:
                        # We need a LF character next.
                        if c == LF:
                            # Unset the part boundary flag.
                            flags &= ~FLAG_PART_BOUNDARY

                            # We have identified a boundary, callback for any data before it.
                            data_callback("part_data", i - index)
                            # Callback indicating that we've reached the end of
                            # a part, and are starting a new one.
                            self.callback("part_end")
                            self.callback("part_begin")

                            # Move to parsing new headers.
                            index = 0
                            state = MultipartState.HEADER_FIELD_START
                            i += 1
                            continue

                        # We didn't find an LF character, so no match. Reset
                        # our index and clear our flag.
                        index = 0
                        flags &= ~FLAG_PART_BOUNDARY

                    # Otherwise, if we're at the last boundary (i.e. we've
                    # seen a hyphen already)...
                    elif flags & FLAG_LAST_BOUNDARY:
                        # We need a second hyphen here.
                        if c == HYPHEN:
                            # We have identified a boundary, callback for any data before it.
                            data_callback("part_data", i - index)
                            # Callback to end the current part, and then the
                            # message.
                            self.callback("part_end")
                            self.callback("end")
                            state = MultipartState.END
                        else:
                            # No match, so reset index.
                            index = 0

                # Otherwise, our index is 0. If the previous index is not, it
                # means we reset something, and we need to take the data we
                # thought was part of our boundary and send it along as actual
                # data.
                if index == 0 and prev_index > 0:
                    # Overwrite our previous index.
                    prev_index = 0

                    # Re-consider the current character, since this could be
                    # the start of the boundary itself.
                    i -= 1

            elif state == MultipartState.END_BOUNDARY:
                if index == len(boundary) - 2 + 1:
                    if c != HYPHEN:
                        msg = "Did not find - at end of boundary (%d)" % (i,)
                        self.logger.warning(msg)
                        e = MultipartParseError(msg)
                        e.offset = i
                        raise e
                    index += 1
                    self.callback("end")
                    state = MultipartState.END

            elif state == MultipartState.END:
                # Don't do anything if chunk ends with CRLF.
                if c == CR and i + 1 < length and data[i + 1] == LF:
                    i += 2
                    continue
                # Skip data after the last boundary.
                self.logger.warning("Skipping data after last boundary")
                i = length
                break

            else: # pragma: no cover (error case)
                # We got into a strange state somehow! Just stop processing.
                msg = "Reached an unknown state %d at %d" % (state, i)
                self.logger.warning(msg)
                e = MultipartParseError(msg)
                e.offset = i
                raise e

            # Move to the next byte.
            i += 1

        # We call our callbacks with any remaining data. Note that we pass
        # the 'remaining' flag, which sets the mark back to 0 instead of
        # deleting it, if it's found. This is because, if the mark is found
        # at this point, we assume that there's data for one of these things
        # that has been parsed, but not yet emitted. And, as such, it implies
        # that we haven't yet reached the end of this 'thing'. So, by setting
        # the mark to 0, we cause any data callbacks that take place in future
        # calls to this function to start from the beginning of that buffer.
        data_callback("header_field", length, True)
        data_callback("header_value", length, True)
        data_callback("part_data", length - index, True)

        # Save values to locals.
        self.state = state
        self.index = index
        self.flags = flags

        # Return our data length to indicate no errors, and that we processed
        # all of it.
        return length

    def finalize(self) -> None:
        """Finalize this parser, which signals to that we are finished parsing.

        Note: It does not currently, but in the future, it will verify that we
        are in the final state of the parser (i.e. the end of the multipart
        message is well-formed), and, if not, throw an error.
        """
        # TODO: verify that we're in the state MultipartState.END, otherwise throw an
        # error or otherwise state that we're not finished parsing.
        pass

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(boundary={self.boundary!r})"
1471
1472
class FormParser:
    """This class is the all-in-one form parser. Given all the information
    necessary to parse a form, it will instantiate the correct parser, create
    the proper :class:`Field` and :class:`File` classes to store the data that
    is parsed, and call the two given callbacks with each field and file as
    they become available.

    Args:
        content_type: The Content-Type of the incoming request. This is used to select the appropriate parser.
        on_field: The callback to call when a field has been parsed and is ready for usage. See above for parameters.
        on_file: The callback to call when a file has been parsed and is ready for usage. See above for parameters.
        on_end: An optional callback to call when all fields and files in a request has been parsed. Can be None.
        boundary: If the request is a multipart/form-data request, this should be the boundary of the request, as given
            in the Content-Type header, as a bytestring.
        file_name: If the request is of type application/octet-stream, then the body of the request will not contain any
            information about the uploaded file. In such cases, you can provide the file name of the uploaded file
            manually.
        FileClass: The class to use for uploaded files. Defaults to :class:`File`, but you can provide your own class
            if you wish to customize behaviour. The class will be instantiated as FileClass(file_name, field_name), and
            it must provide the following functions::
                - file_instance.write(data)
                - file_instance.finalize()
                - file_instance.close()
        FieldClass: The class to use for uploaded fields. Defaults to :class:`Field`, but you can provide your own
            class if you wish to customize behaviour. The class will be instantiated as FieldClass(field_name), and it
            must provide the following functions::
                - field_instance.write(data)
                - field_instance.finalize()
                - field_instance.close()
                - field_instance.set_none()
        config: Configuration to use for this FormParser. The default values are taken from the DEFAULT_CONFIG value,
            and then any keys present in this dictionary will overwrite the default values.

    Raises:
        FormParserError: If the Content-Type is multipart/form-data but no boundary was given, or if the Content-Type
            is not one of the supported types.
    """

    #: This is the default configuration for our form parser.
    #: Note: all file sizes should be in bytes.
    DEFAULT_CONFIG: FormParserConfig = {
        "MAX_BODY_SIZE": float("inf"),
        "MAX_MEMORY_FILE_SIZE": 1 * 1024 * 1024,
        "UPLOAD_DIR": None,
        "UPLOAD_KEEP_FILENAME": False,
        "UPLOAD_KEEP_EXTENSIONS": False,
        # Error on invalid Content-Transfer-Encoding?
        "UPLOAD_ERROR_ON_BAD_CTE": False,
    }

    def __init__(
        self,
        content_type: str,
        on_field: OnFieldCallback | None,
        on_file: OnFileCallback | None,
        on_end: Callable[[], None] | None = None,
        boundary: bytes | str | None = None,
        file_name: bytes | None = None,
        FileClass: type[FileProtocol] = File,
        FieldClass: type[FieldProtocol] = Field,
        config: dict[Any, Any] = {},
    ) -> None:
        """Select and configure the underlying parser based on content_type.

        Note: the default `config` dict is shared between calls but is only
        ever read (copied into self.config), never mutated, so the mutable
        default is safe here.
        """
        self.logger = logging.getLogger(__name__)

        # Save variables.
        self.content_type = content_type
        self.boundary = boundary
        self.bytes_received = 0
        self.parser = None

        # Save callbacks.
        self.on_field = on_field
        self.on_file = on_file
        self.on_end = on_end

        # Save classes. BUGFIX: previously these were hard-coded to the
        # default File/Field classes, silently discarding any custom
        # FileClass/FieldClass passed by the caller.
        self.FileClass = FileClass
        self.FieldClass = FieldClass

        # Set configuration options: start from the defaults, then overlay
        # any caller-supplied keys.
        self.config: FormParserConfig = self.DEFAULT_CONFIG.copy()
        self.config.update(config)  # type: ignore[typeddict-item]

        parser: OctetStreamParser | MultipartParser | QuerystringParser | None = None

        # Depending on the Content-Type, we instantiate the correct parser.
        if content_type == "application/octet-stream":
            # The whole body is a single anonymous file; wrap it in FileClass.
            file: FileProtocol = None  # type: ignore

            def on_start() -> None:
                nonlocal file
                file = FileClass(file_name, None, config=cast("FileConfig", self.config))

            def on_data(data: bytes, start: int, end: int) -> None:
                nonlocal file
                file.write(data[start:end])

            def _on_end() -> None:
                nonlocal file
                # Finalize the file itself.
                file.finalize()

                # Call our callback.
                if on_file:
                    on_file(file)

                # Call the on-end callback.
                if self.on_end is not None:
                    self.on_end()

            # Instantiate an octet-stream parser
            parser = OctetStreamParser(
                callbacks={"on_start": on_start, "on_data": on_data, "on_end": _on_end},
                max_size=self.config["MAX_BODY_SIZE"],
            )

        elif content_type == "application/x-www-form-urlencoded" or content_type == "application/x-url-encoded":
            # Accumulates the (possibly chunked) field name until we see data.
            name_buffer: list[bytes] = []

            f: FieldProtocol | None = None

            def on_field_start() -> None:
                pass

            def on_field_name(data: bytes, start: int, end: int) -> None:
                name_buffer.append(data[start:end])

            def on_field_data(data: bytes, start: int, end: int) -> None:
                nonlocal f
                if f is None:
                    f = FieldClass(b"".join(name_buffer))
                    del name_buffer[:]
                f.write(data[start:end])

            def on_field_end() -> None:
                nonlocal f
                # Finalize and call callback.
                if f is None:
                    # If we get here, it's because there was no field data.
                    # We create a field, set it to None, and then continue.
                    f = FieldClass(b"".join(name_buffer))
                    del name_buffer[:]
                    f.set_none()

                f.finalize()
                if on_field:
                    on_field(f)
                f = None

            def _on_end() -> None:
                if self.on_end is not None:
                    self.on_end()

            # Instantiate parser.
            parser = QuerystringParser(
                callbacks={
                    "on_field_start": on_field_start,
                    "on_field_name": on_field_name,
                    "on_field_data": on_field_data,
                    "on_field_end": on_field_end,
                    "on_end": _on_end,
                },
                max_size=self.config["MAX_BODY_SIZE"],
            )

        elif content_type == "multipart/form-data":
            if boundary is None:
                self.logger.error("No boundary given")
                raise FormParserError("No boundary given")

            # Per-part header accumulation buffers.
            header_name: list[bytes] = []
            header_value: list[bytes] = []
            headers: dict[bytes, bytes] = {}

            # The File or Field for the part currently being parsed, and the
            # writer (possibly a transfer-encoding decoder) wrapping it.
            f_multi: FileProtocol | FieldProtocol | None = None
            writer = None
            is_file = False

            def on_part_begin() -> None:
                # Reset headers in case this isn't the first part.
                nonlocal headers
                headers = {}

            def on_part_data(data: bytes, start: int, end: int) -> None:
                nonlocal writer
                assert writer is not None
                writer.write(data[start:end])
                # TODO: check for error here.

            def on_part_end() -> None:
                nonlocal f_multi, is_file
                assert f_multi is not None
                f_multi.finalize()
                if is_file:
                    if on_file:
                        on_file(f_multi)
                else:
                    if on_field:
                        on_field(cast("FieldProtocol", f_multi))

            def on_header_field(data: bytes, start: int, end: int) -> None:
                header_name.append(data[start:end])

            def on_header_value(data: bytes, start: int, end: int) -> None:
                header_value.append(data[start:end])

            def on_header_end() -> None:
                headers[b"".join(header_name)] = b"".join(header_value)
                del header_name[:]
                del header_value[:]

            def on_headers_finished() -> None:
                nonlocal is_file, f_multi, writer
                # Reset the 'is file' flag.
                is_file = False

                # Parse the content-disposition header.
                # TODO: handle mixed case
                content_disp = headers.get(b"Content-Disposition")
                disp, options = parse_options_header(content_disp)

                # Get the field and filename.
                field_name = options.get(b"name")
                file_name = options.get(b"filename")
                # TODO: check for errors

                # Create the proper class: a filename means the part is a file.
                if file_name is None:
                    f_multi = FieldClass(field_name)
                else:
                    f_multi = FileClass(file_name, field_name, config=cast("FileConfig", self.config))
                    is_file = True

                # Parse the given Content-Transfer-Encoding to determine what
                # we need to do with the incoming data.
                # TODO: check that we properly handle 8bit / 7bit encoding.
                transfer_encoding = headers.get(b"Content-Transfer-Encoding", b"7bit")

                if transfer_encoding in (b"binary", b"8bit", b"7bit"):
                    writer = f_multi

                elif transfer_encoding == b"base64":
                    writer = Base64Decoder(f_multi)

                elif transfer_encoding == b"quoted-printable":
                    writer = QuotedPrintableDecoder(f_multi)

                else:
                    self.logger.warning("Unknown Content-Transfer-Encoding: %r", transfer_encoding)
                    if self.config["UPLOAD_ERROR_ON_BAD_CTE"]:
                        raise FormParserError(f'Unknown Content-Transfer-Encoding "{transfer_encoding!r}"')
                    else:
                        # If we aren't erroring, then we just treat this as an
                        # unencoded Content-Transfer-Encoding.
                        writer = f_multi

            def _on_end() -> None:
                nonlocal writer
                if writer is not None:
                    writer.finalize()
                if self.on_end is not None:
                    self.on_end()

            # Instantiate a multipart parser.
            parser = MultipartParser(
                boundary,
                callbacks={
                    "on_part_begin": on_part_begin,
                    "on_part_data": on_part_data,
                    "on_part_end": on_part_end,
                    "on_header_field": on_header_field,
                    "on_header_value": on_header_value,
                    "on_header_end": on_header_end,
                    "on_headers_finished": on_headers_finished,
                    "on_end": _on_end,
                },
                max_size=self.config["MAX_BODY_SIZE"],
            )

        else:
            self.logger.warning("Unknown Content-Type: %r", content_type)
            raise FormParserError(f"Unknown Content-Type: {content_type}")

        self.parser = parser

    def write(self, data: bytes) -> int:
        """Write some data. The parser will forward this to the appropriate
        underlying parser.

        Args:
            data: The data to write.

        Returns:
            The number of bytes processed.
        """
        self.bytes_received += len(data)
        # TODO: check the parser's return value for errors?
        assert self.parser is not None
        return self.parser.write(data)

    def finalize(self) -> None:
        """Finalize the parser."""
        if self.parser is not None and hasattr(self.parser, "finalize"):
            self.parser.finalize()

    def close(self) -> None:
        """Close the parser."""
        if self.parser is not None and hasattr(self.parser, "close"):
            self.parser.close()

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(content_type={self.content_type!r}, parser={self.parser!r})"
1781
1782
def create_form_parser(
    headers: dict[str, bytes],
    on_field: OnFieldCallback | None,
    on_file: OnFileCallback | None,
    trust_x_headers: bool = False,
    config: dict[Any, Any] = {},
) -> FormParser:
    """This function is a helper function to aid in creating a FormParser
    instances. Given a dictionary-like headers object, it will determine
    the correct information needed, instantiate a FormParser with the
    appropriate values and given callbacks, and then return the corresponding
    parser.

    Args:
        headers: A dictionary-like object of HTTP headers. The only required header is Content-Type.
        on_field: Callback to call with each parsed field.
        on_file: Callback to call with each parsed file.
        trust_x_headers: Whether or not to trust information received from certain X-Headers - for example, the file
            name from X-File-Name.
        config: Configuration variables to pass to the FormParser.

    Raises:
        ValueError: If no Content-Type header is present.
    """
    content_type: str | bytes | None = headers.get("Content-Type")
    if content_type is None:
        logging.getLogger(__name__).warning("No Content-Type header given")
        raise ValueError("No Content-Type header given!")

    # Boundaries are optional (the FormParser will raise if one is needed
    # but not given).
    content_type, params = parse_options_header(content_type)
    boundary = params.get(b"boundary")

    # We need content_type to be a string, not a bytes object.
    content_type = content_type.decode("latin-1")

    # File names are optional. BUGFIX: the client-supplied X-File-Name header
    # is only honored when the caller explicitly opted in via trust_x_headers;
    # previously it was trusted unconditionally.
    file_name = headers.get("X-File-Name") if trust_x_headers else None

    # Instantiate a form parser.
    form_parser = FormParser(content_type, on_field, on_file, boundary=boundary, file_name=file_name, config=config)

    # Return our parser.
    return form_parser
1825
1826
def parse_form(
    headers: dict[str, bytes],
    input_stream: SupportsRead,
    on_field: OnFieldCallback | None,
    on_file: OnFileCallback | None,
    chunk_size: int = 1048576,
) -> None:
    """Parse a request body in one call, with minimal setup.

    Pass a dictionary-like object of the request's headers and a file-like
    object for the input stream, along with two callbacks that fire whenever
    a field or file has been fully parsed.

    Args:
        headers: A dictionary-like object of HTTP headers. The only required header is Content-Type.
        input_stream: A file-like object that represents the request body. The read() method must return bytestrings.
        on_field: Callback to call with each parsed field.
        on_file: Callback to call with each parsed file.
        chunk_size: The maximum size to read from the input stream and write to the parser at one time.
            Defaults to 1 MiB.
    """
    # Build the appropriate parser from the request headers.
    form_parser = create_form_parser(headers, on_field, on_file)

    # Honor Content-Length when present; otherwise read until EOF.
    raw_length: int | float | bytes | None = headers.get("Content-Length")
    limit: int | float = int(raw_length) if raw_length is not None else float("inf")

    total_read = 0
    while True:
        # Never request more than the remaining Content-Length allows,
        # nor more than one chunk at a time.
        to_read = int(min(limit - total_read, chunk_size))
        chunk = input_stream.read(to_read)

        # Feed the parser and account for what we actually received.
        form_parser.write(chunk)
        total_read += len(chunk)

        # A short read means EOF; reaching the declared length means done.
        if len(chunk) != to_read or total_read == limit:
            break

    # Signal end-of-input to the parser.
    form_parser.finalize()