1from __future__ import annotations
2
3import logging
4import os
5import shutil
6import sys
7import tempfile
8from email.message import Message
9from enum import IntEnum
10from io import BufferedRandom, BytesIO
11from numbers import Number
12from typing import TYPE_CHECKING, cast
13
14from .decoders import Base64Decoder, QuotedPrintableDecoder
15from .exceptions import FileError, FormParserError, MultipartParseError, QuerystringParseError
16
# Static-typing-only declarations. Nothing in this block exists at runtime,
# since TYPE_CHECKING is False outside of type checkers.
if TYPE_CHECKING:  # pragma: no cover
    from typing import Any, Callable, Literal, Protocol, TypedDict

    from typing_extensions import TypeAlias

    class SupportsRead(Protocol):
        """Structural type for any object we can read bytes from."""

        def read(self, __n: int) -> bytes: ...

    class QuerystringCallbacks(TypedDict, total=False):
        """Optional callbacks accepted by the querystring parser."""

        on_field_start: Callable[[], None]
        on_field_name: Callable[[bytes, int, int], None]
        on_field_data: Callable[[bytes, int, int], None]
        on_field_end: Callable[[], None]
        on_end: Callable[[], None]

    class OctetStreamCallbacks(TypedDict, total=False):
        """Optional callbacks accepted by the octet-stream parser."""

        on_start: Callable[[], None]
        on_data: Callable[[bytes, int, int], None]
        on_end: Callable[[], None]

    class MultipartCallbacks(TypedDict, total=False):
        """Optional callbacks accepted by the multipart parser."""

        on_part_begin: Callable[[], None]
        on_part_data: Callable[[bytes, int, int], None]
        on_part_end: Callable[[], None]
        on_header_begin: Callable[[], None]
        on_header_field: Callable[[bytes, int, int], None]
        on_header_value: Callable[[bytes, int, int], None]
        on_header_end: Callable[[], None]
        on_headers_finished: Callable[[], None]
        on_end: Callable[[], None]

    class FormParserConfig(TypedDict):
        """Configuration keys understood by the form parser (all required)."""

        UPLOAD_DIR: str | None
        UPLOAD_KEEP_FILENAME: bool
        UPLOAD_KEEP_EXTENSIONS: bool
        UPLOAD_ERROR_ON_BAD_CTE: bool
        MAX_MEMORY_FILE_SIZE: int
        MAX_BODY_SIZE: float

    class FileConfig(TypedDict, total=False):
        """Configuration keys understood by the File class (all optional)."""

        UPLOAD_DIR: str | bytes | None
        UPLOAD_DELETE_TMP: bool
        UPLOAD_KEEP_FILENAME: bool
        UPLOAD_KEEP_EXTENSIONS: bool
        MAX_MEMORY_FILE_SIZE: int

    class _FormProtocol(Protocol):
        """Common write/finalize/close interface shared by fields and files."""

        def write(self, data: bytes) -> int: ...

        def finalize(self) -> None: ...

        def close(self) -> None: ...

    class FieldProtocol(_FormProtocol, Protocol):
        """Interface a Field-like object must implement."""

        def __init__(self, name: bytes | None) -> None: ...

        def set_none(self) -> None: ...

    class FileProtocol(_FormProtocol, Protocol):
        """Interface a File-like object must implement."""

        def __init__(self, file_name: bytes | None, field_name: bytes | None, config: FileConfig) -> None: ...

    OnFieldCallback = Callable[[FieldProtocol], None]
    OnFileCallback = Callable[[FileProtocol], None]

    # Names accepted by BaseParser.callback()/set_callback(); the "on_"
    # prefix is added internally.
    CallbackName: TypeAlias = Literal[
        "start",
        "data",
        "end",
        "field_start",
        "field_name",
        "field_data",
        "field_end",
        "part_begin",
        "part_data",
        "part_end",
        "header_begin",
        "header_field",
        "header_value",
        "header_end",
        "headers_finished",
    ]
98
99# Unique missing object.
100_missing = object()
101
102
class QuerystringState(IntEnum):
    """Querystring parser states.

    These are used to keep track of the state of the parser, and are used to determine
    what to do when new data is encountered.
    """

    BEFORE_FIELD = 0  # Between fields: skipping '&'/';' separators until real data starts.
    FIELD_NAME = 1  # Reading the field name, up to '=' or a separator.
    FIELD_DATA = 2  # Reading the field value, up to the next separator.
113
114
class MultipartState(IntEnum):
    """Multipart parser states.

    These are used to keep track of the state of the parser, and are used to determine
    what to do when new data is encountered.
    """

    # The states are ordered to follow the wire format of a multipart body:
    # opening boundary, then per-part headers, then part data, then either
    # another boundary or the closing ("last") boundary.
    START = 0
    START_BOUNDARY = 1
    HEADER_FIELD_START = 2
    HEADER_FIELD = 3
    HEADER_VALUE_START = 4
    HEADER_VALUE = 5
    HEADER_VALUE_ALMOST_DONE = 6
    HEADERS_ALMOST_DONE = 7
    PART_DATA_START = 8
    PART_DATA = 9
    PART_DATA_END = 10
    END_BOUNDARY = 11
    END = 12
135
136
# Flags for the multipart parser.
FLAG_PART_BOUNDARY = 1
FLAG_LAST_BOUNDARY = 2

# Byte-value constants used in comparisons while scanning input.
# Indexing a bytes object yields an int (e.g. b"\r"[0] == 13), so we
# precompute the integer values of the characters we test against.
CR = b"\r"[0]
LF = b"\n"[0]
COLON = b":"[0]
SPACE = b" "[0]
HYPHEN = b"-"[0]
AMPERSAND = b"&"[0]
SEMICOLON = b";"[0]
LOWER_A = b"a"[0]
LOWER_Z = b"z"[0]
NULL = b"\x00"[0]

# fmt: off
# Mask for ASCII characters that can be http tokens.
# Per RFC7230 - 3.2.6, this is all alpha-numeric characters
# and these: !#$%&'*+-.^_`|~
TOKEN_CHARS_SET = frozenset(
    b"ABCDEFGHIJKLMNOPQRSTUVWXYZ"
    b"abcdefghijklmnopqrstuvwxyz"
    b"0123456789"
    b"!#$%&'*+-.^_`|~")
# fmt: on
165
166
def parse_options_header(value: str | bytes | None) -> tuple[bytes, dict[bytes, bytes]]:
    """Parses a Content-Type header into a value in the following format: (content_type, {parameters})."""
    # An empty or missing header parses to an empty content type.
    if not value:
        return (b"", {})

    # Bytes input is assumed to follow WSGI conventions, i.e. latin-1 encoded.
    if isinstance(value, bytes):  # pragma: no cover
        value = value.decode("latin-1")

    # For types
    assert isinstance(value, str), "Value should be a string by now"

    # Fast path: no ';' means there are no parameters at all.
    if ";" not in value:
        return (value.lower().strip().encode("latin-1"), {})

    # Delegate splitting of the value and its parameters to the stdlib email
    # machinery (the documented replacement for the removed cgi module, see
    # PEP 594: https://peps.python.org/pep-0594/#cgi).
    message = Message()
    message["content-type"] = value
    params = message.get_params()
    # We only get here when a ';' was present, so the content type itself
    # must have been parsed out.
    assert params, "At least the content type value should be present"

    ctype = params.pop(0)[0].encode("latin-1")
    options: dict[bytes, bytes] = {}
    for key, param_value in params:
        # get_params() may return an RFC 2231 (charset, language, value)
        # 3-tuple; the actual value is the last element.
        # See: https://docs.python.org/3/library/email.compat32-message.html
        if isinstance(param_value, tuple):
            param_value = param_value[-1]
        # Work around IE6 sending the client's full file path as the
        # filename: keep only the final path component.
        if key == "filename" and (param_value[1:3] == ":\\" or param_value[:2] == "\\\\"):
            param_value = param_value.rsplit("\\", 1)[-1]
        options[key.encode("latin-1")] = param_value.encode("latin-1")
    return ctype, options
208
209
class Field:
    """A Field object represents a (parsed) form field. It represents a single
    field with a corresponding name and value.

    The name that a :class:`Field` will be instantiated with is the same name
    that would be found in the following HTML::

        <input name="name_goes_here" type="text"/>

    Incoming data is delivered through :meth:`on_data`, and :meth:`on_end` is
    invoked once the field is complete.

    Args:
        name: The name of the form field.
    """

    def __init__(self, name: bytes | None) -> None:
        self._name = name
        self._value: list[bytes] = []

        # Lazily-computed joined form of _value. The module-level _missing
        # sentinel distinguishes "not computed yet" from a legitimate cached
        # value of None (set via set_none()).
        self._cache = _missing

    @classmethod
    def from_value(cls, name: bytes, value: bytes | None) -> Field:
        """Create an instance of a :class:`Field`, and set the corresponding
        value - either None or an actual value. This method will also
        finalize the Field itself.

        Args:
            name: the name of the form field.
            value: the value of the form field - either a bytestring or None.

        Returns:
            A new instance of a [`Field`][python_multipart.Field].
        """
        field = cls(name)
        if value is None:
            field.set_none()
        else:
            field.write(value)
        field.finalize()
        return field

    def _materialize(self) -> None:
        # Join the accumulated chunks into the cache, unless already done.
        if self._cache is _missing:
            self._cache = b"".join(self._value)

    def write(self, data: bytes) -> int:
        """Write some data into the form field.

        Args:
            data: The data to write to the field.

        Returns:
            The number of bytes written.
        """
        return self.on_data(data)

    def on_data(self, data: bytes) -> int:
        """This method is a callback that will be called whenever data is
        written to the Field.

        Args:
            data: The data to write to the field.

        Returns:
            The number of bytes written.
        """
        self._value.append(data)
        # New data invalidates any previously joined value.
        self._cache = _missing
        return len(data)

    def on_end(self) -> None:
        """This method is called whenever the Field is finalized."""
        self._materialize()

    def finalize(self) -> None:
        """Finalize the form field."""
        self.on_end()

    def close(self) -> None:
        """Close the Field object. This will free any underlying cache."""
        self._materialize()
        # Release the accumulated chunk list; only the joined cache remains.
        del self._value

    def set_none(self) -> None:
        """Some fields in a querystring can possibly have a value of None - for
        example, the string "foo&bar=&baz=asdf" will have a field with the
        name "foo" and value None, one with name "bar" and value "", and one
        with name "baz" and value "asdf". Since the write() interface doesn't
        support writing None, this function will set the field value to None.
        """
        self._cache = None

    @property
    def field_name(self) -> bytes | None:
        """This property returns the name of the field."""
        return self._name

    @property
    def value(self) -> bytes | None:
        """This property returns the value of the form field."""
        self._materialize()
        assert isinstance(self._cache, bytes) or self._cache is None
        return self._cache

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, Field):
            return NotImplemented
        return self.field_name == other.field_name and self.value == other.value

    def __repr__(self) -> str:
        val = self.value
        if val is not None and len(val) > 97:
            # Truncate the repr and splice an ellipsis in before the final quote.
            shown = repr(val[:97])[:-1] + "...'"
        else:
            shown = repr(val)

        return "{}(field_name={!r}, value={})".format(self.__class__.__name__, self.field_name, shown)
336
337
class File:
    """This class represents an uploaded file. It handles writing file data to
    either an in-memory file or a temporary file on-disk, if the optional
    threshold is passed.

    There are some options that can be passed to the File to change behavior
    of the class. Valid options are as follows:

    | Name | Type | Default | Description |
    |-----------------------|-------|---------|-------------|
    | UPLOAD_DIR | `str` | None | The directory to store uploaded files in. If this is None, a temporary file will be created in the system's standard location. |
    | UPLOAD_DELETE_TMP | `bool`| True | Delete automatically created TMP file |
    | UPLOAD_KEEP_FILENAME | `bool`| False | Whether or not to keep the filename of the uploaded file. If True, then the filename will be converted to a safe representation (e.g. by removing any invalid path segments), and then saved with the same name). Otherwise, a temporary name will be used. |
    | UPLOAD_KEEP_EXTENSIONS| `bool`| False | Whether or not to keep the uploaded file's extension. If False, the file will be saved with the default temporary extension (usually ".tmp"). Otherwise, the file's extension will be maintained. Note that this will properly combine with the UPLOAD_KEEP_FILENAME setting. |
    | MAX_MEMORY_FILE_SIZE | `int` | 1 MiB | The maximum number of bytes of a File to keep in memory. By default, the contents of a File are kept into memory until a certain limit is reached, after which the contents of the File are written to a temporary file. This behavior can be disabled by setting this value to an appropriately large value (or, for example, infinity, such as `float('inf')`. |

    Args:
        file_name: The name of the file that this [`File`][python_multipart.File] represents.
        field_name: The name of the form field that this file was uploaded with. This can be None, if, for example,
            the file was uploaded with Content-Type application/octet-stream.
        config: The configuration for this File. See above for valid configuration keys and their corresponding values.
    """  # noqa: E501

    def __init__(self, file_name: bytes | None, field_name: bytes | None = None, config: FileConfig = {}) -> None:
        # NOTE: the shared `{}` default is safe here because the config is
        # only ever read (via .get()), never mutated.
        # Save configuration, set other variables default.
        self.logger = logging.getLogger(__name__)
        self._config = config
        self._in_memory = True
        self._bytes_written = 0
        self._fileobj: BytesIO | BufferedRandom = BytesIO()

        # Save the provided field/file name.
        self._field_name = field_name
        self._file_name = file_name

        # Our actual file name is None by default, since, depending on our
        # config, we may not actually use the provided name.
        self._actual_file_name: bytes | None = None

        # Split the extension from the filename. Default both pieces to b""
        # so that _get_disk_file() works even when no filename was provided
        # (e.g. an application/octet-stream upload with UPLOAD_KEEP_EXTENSIONS
        # set) instead of raising AttributeError.
        self._file_base: bytes = b""
        self._ext: bytes = b""
        if file_name is not None:
            self._file_base, self._ext = os.path.splitext(file_name)

    @property
    def field_name(self) -> bytes | None:
        """The form field associated with this file. May be None if there isn't
        one, for example when we have an application/octet-stream upload.
        """
        return self._field_name

    @property
    def file_name(self) -> bytes | None:
        """The file name given in the upload request."""
        return self._file_name

    @property
    def actual_file_name(self) -> bytes | None:
        """The file name that this file is saved as. Will be None if it's not
        currently saved on disk.
        """
        return self._actual_file_name

    @property
    def file_object(self) -> BytesIO | BufferedRandom:
        """The file object that we're currently writing to. Note that this
        will either be an instance of a :class:`io.BytesIO`, or a regular file
        object.
        """
        return self._fileobj

    @property
    def size(self) -> int:
        """The total size of this file, counted as the number of bytes that
        currently have been written to the file.
        """
        return self._bytes_written

    @property
    def in_memory(self) -> bool:
        """A boolean representing whether or not this file object is currently
        stored in-memory or on-disk.
        """
        return self._in_memory

    def flush_to_disk(self) -> None:
        """If the file is already on-disk, do nothing. Otherwise, copy from
        the in-memory buffer to a disk file, and then reassign our internal
        file object to this new disk file.

        Note that if you attempt to flush a file that is already on-disk, a
        warning will be logged to this module's logger.
        """
        if not self._in_memory:
            self.logger.warning("Trying to flush to disk when we're not in memory")
            return

        # Go back to the start of our file.
        self._fileobj.seek(0)

        # Open a new file.
        new_file = self._get_disk_file()

        # Copy the file objects.
        shutil.copyfileobj(self._fileobj, new_file)

        # Seek to the new position in our new file.
        new_file.seek(self._bytes_written)

        # Reassign the fileobject.
        old_fileobj = self._fileobj
        self._fileobj = new_file

        # We're no longer in memory.
        self._in_memory = False

        # Close the old file object.
        old_fileobj.close()

    def _get_disk_file(self) -> BufferedRandom:
        """This function is responsible for getting a file object on-disk for us.

        Raises:
            FileError: If the target file (or temporary file) cannot be opened.
        """
        self.logger.info("Opening a file on disk")

        file_dir = self._config.get("UPLOAD_DIR")
        keep_filename = self._config.get("UPLOAD_KEEP_FILENAME", False)
        keep_extensions = self._config.get("UPLOAD_KEEP_EXTENSIONS", False)
        delete_tmp = self._config.get("UPLOAD_DELETE_TMP", True)
        tmp_file: None | BufferedRandom = None

        # If we have a directory and are to keep the filename...
        if file_dir is not None and keep_filename:
            self.logger.info("Saving with filename in: %r", file_dir)

            # Build our filename (empty _file_base/_ext if no name was given).
            fname = self._file_base + self._ext if keep_extensions else self._file_base

            path = os.path.join(file_dir, fname)  # type: ignore[arg-type]
            try:
                self.logger.info("Opening file: %r", path)
                tmp_file = open(path, "w+b")
            except OSError as e:
                tmp_file = None

                self.logger.exception("Error opening temporary file")
                raise FileError("Error opening temporary file: %r" % path) from e
        else:
            # Build options array.
            # Note that on Python 3, tempfile doesn't support byte names. We
            # encode our paths using the default filesystem encoding.
            suffix = self._ext.decode(sys.getfilesystemencoding()) if keep_extensions else None

            if file_dir is None:
                dir = None
            elif isinstance(file_dir, bytes):
                dir = file_dir.decode(sys.getfilesystemencoding())
            else:
                dir = file_dir  # pragma: no cover

            # Create a temporary (named) file with the appropriate settings.
            self.logger.info(
                "Creating a temporary file with options: %r", {"suffix": suffix, "delete": delete_tmp, "dir": dir}
            )
            try:
                tmp_file = cast(BufferedRandom, tempfile.NamedTemporaryFile(suffix=suffix, delete=delete_tmp, dir=dir))
            except OSError as e:
                self.logger.exception("Error creating named temporary file")
                raise FileError("Error creating named temporary file") from e

        assert tmp_file is not None
        # Encode filename as bytes.
        if isinstance(tmp_file.name, str):
            fname = tmp_file.name.encode(sys.getfilesystemencoding())
        else:
            fname = cast(bytes, tmp_file.name)  # pragma: no cover

        self._actual_file_name = fname
        return tmp_file

    def write(self, data: bytes) -> int:
        """Write some data to the File.

        :param data: a bytestring
        """
        return self.on_data(data)

    def on_data(self, data: bytes) -> int:
        """This method is a callback that will be called whenever data is
        written to the File.

        Args:
            data: The data to write to the file.

        Returns:
            The number of bytes written.
        """
        bwritten = self._fileobj.write(data)

        # If the bytes written isn't the same as the length, just return.
        if bwritten != len(data):
            self.logger.warning("bwritten != len(data) (%d != %d)", bwritten, len(data))
            return bwritten

        # Keep track of how many bytes we've written.
        self._bytes_written += bwritten

        # If we're in-memory and are over our limit, we create a file.
        max_memory_file_size = self._config.get("MAX_MEMORY_FILE_SIZE")
        if self._in_memory and max_memory_file_size is not None and (self._bytes_written > max_memory_file_size):
            self.logger.info("Flushing to disk")
            self.flush_to_disk()

        # Return the number of bytes written.
        return bwritten

    def on_end(self) -> None:
        """This method is called whenever the Field is finalized."""
        # Flush the underlying file object
        self._fileobj.flush()

    def finalize(self) -> None:
        """Finalize the form file. This will not close the underlying file,
        but simply signal that we are finished writing to the File.
        """
        self.on_end()

    def close(self) -> None:
        """Close the File object. This will actually close the underlying
        file object (whether it's a :class:`io.BytesIO` or an actual file
        object).
        """
        self._fileobj.close()

    def __repr__(self) -> str:
        return "{}(file_name={!r}, field_name={!r})".format(self.__class__.__name__, self.file_name, self.field_name)
574
575
class BaseParser:
    """This class is the base class for all parsers. It contains the logic for
    calling and adding callbacks.

    A callback can be one of two different forms. "Notification callbacks" are
    callbacks that are called when something happens - for example, when a new
    part of a multipart message is encountered by the parser. "Data callbacks"
    are called when we get some sort of data - for example, part of the body of
    a multipart chunk. Notification callbacks are called with no parameters,
    whereas data callbacks are called with three, as follows::

        data_callback(data, start, end)

    The "data" parameter is a bytestring. "start" and "end" are integer
    indexes into "data" that delimit the bytes of interest: the slice
    `data[start:end]` is what the callback should look at. The data is not
    copied before being handed to the callback, since copying severely hurts
    performance.
    """

    def __init__(self) -> None:
        self.logger = logging.getLogger(__name__)
        # Registered callbacks, keyed by "on_"-prefixed name.
        self.callbacks: QuerystringCallbacks | OctetStreamCallbacks | MultipartCallbacks = {}

    def callback(
        self, name: CallbackName, data: bytes | None = None, start: int | None = None, end: int | None = None
    ) -> None:
        """This function calls a provided callback with some data. If the
        callback is not set, will do nothing.

        Args:
            name: The name of the callback to call (as a string).
            data: Data to pass to the callback. If None, then it is assumed that the callback is a notification
                callback, and no parameters are given.
            end: An integer that is passed to the data callback.
            start: An integer that is passed to the data callback.
        """
        callback_name = "on_" + name
        handler = self.callbacks.get(callback_name)
        if handler is None:
            return
        handler = cast("Callable[..., Any]", handler)

        if data is None:
            # Notification callback: no arguments.
            self.logger.debug("Calling %s with no data", callback_name)
            handler()
            return

        # Data callback: skip entirely for an empty slice.
        if start is not None and start == end:
            return

        self.logger.debug("Calling %s with data[%d:%d]", callback_name, start, end)
        handler(data, start, end)

    def set_callback(self, name: CallbackName, new_func: Callable[..., Any] | None) -> None:
        """Update the function for a callback. Removes from the callbacks dict
        if new_func is None.

        :param name: The name of the callback to call (as a string).

        :param new_func: The new function for the callback. If None, then the
                         callback will be removed (with no error if it does not
                         exist).
        """
        key = "on_" + name
        if new_func is not None:
            self.callbacks[key] = new_func  # type: ignore[literal-required]
        else:
            self.callbacks.pop(key, None)  # type: ignore[misc]

    def close(self) -> None:
        pass  # pragma: no cover

    def finalize(self) -> None:
        pass  # pragma: no cover

    def __repr__(self) -> str:
        return "%s()" % self.__class__.__name__
654
655
class OctetStreamParser(BaseParser):
    """This parser parses an octet-stream request body and calls callbacks when
    incoming data is received. Callbacks are as follows:

    | Callback Name | Parameters | Description |
    |----------------|-----------------|-----------------------------------------------------|
    | on_start | None | Called when the first data is parsed. |
    | on_data | data, start, end| Called for each data chunk that is parsed. |
    | on_end | None | Called when the parser is finished parsing all data.|

    Args:
        callbacks: A dictionary of callbacks. See the documentation for [`BaseParser`][python_multipart.BaseParser].
            Defaults to no callbacks.
        max_size: The maximum size of body to parse. Defaults to infinity - i.e. unbounded.

    Raises:
        ValueError: If max_size is not a number, or is smaller than 1.
    """

    def __init__(self, callbacks: OctetStreamCallbacks | None = None, max_size: float = float("inf")):
        super().__init__()
        # Use None as the default and create a fresh dict per instance: a
        # shared `{}` default argument would be mutated by set_callback()
        # and leak callbacks across unrelated parser instances.
        self.callbacks = callbacks if callbacks is not None else {}
        # Whether the on_start callback has been emitted yet.
        self._started = False

        if not isinstance(max_size, Number) or max_size < 1:
            raise ValueError("max_size must be a positive number, not %r" % max_size)
        self.max_size: int | float = max_size
        # Total bytes accepted so far; used to enforce max_size.
        self._current_size = 0

    def write(self, data: bytes) -> int:
        """Write some data to the parser, which will perform size verification,
        and then pass the data to the underlying callback.

        Args:
            data: The data to write to the parser.

        Returns:
            The number of bytes written.
        """
        if not self._started:
            self.callback("start")
            self._started = True

        # Truncate data length so we never exceed max_size in total.
        data_len = len(data)
        if (self._current_size + data_len) > self.max_size:
            # We truncate the length of data that we are to process.
            new_size = int(self.max_size - self._current_size)
            self.logger.warning(
                "Current size is %d (max %d), so truncating data length from %d to %d",
                self._current_size,
                self.max_size,
                data_len,
                new_size,
            )
            data_len = new_size

        # Increment size, then callback, in case there's an exception.
        self._current_size += data_len
        self.callback("data", data, 0, data_len)
        return data_len

    def finalize(self) -> None:
        """Finalize this parser, which signals to that we are finished parsing,
        and sends the on_end callback.
        """
        self.callback("end")

    def __repr__(self) -> str:
        return "%s()" % self.__class__.__name__
722
723
class QuerystringParser(BaseParser):
    """This is a streaming querystring parser. It will consume data, and call
    the callbacks given when it has data.

    | Callback Name | Parameters | Description |
    |----------------|-----------------|-----------------------------------------------------|
    | on_field_start | None | Called when a new field is encountered. |
    | on_field_name | data, start, end| Called when a portion of a field's name is encountered. |
    | on_field_data | data, start, end| Called when a portion of a field's data is encountered. |
    | on_field_end | None | Called when the end of a field is encountered. |
    | on_end | None | Called when the parser is finished parsing all data.|

    Args:
        callbacks: A dictionary of callbacks. See the documentation for [`BaseParser`][python_multipart.BaseParser].
            Defaults to no callbacks.
        strict_parsing: Whether or not to parse the body strictly. Defaults to False. If this is set to True, then the
            behavior of the parser changes as the following: if a field has a value with an equal sign
            (e.g. "foo=bar", or "foo="), it is always included. If a field has no equals sign (e.g. "...&name&..."),
            it will be treated as an error if 'strict_parsing' is True, otherwise included. If an error is encountered,
            then a [`QuerystringParseError`][python_multipart.exceptions.QuerystringParseError] will be raised.
        max_size: The maximum size of body to parse. Defaults to infinity - i.e. unbounded.

    Raises:
        ValueError: If max_size is not a number, or is smaller than 1.
    """  # noqa: E501

    # Current parser state; persists across write() calls so that fields may
    # span chunk boundaries.
    state: QuerystringState

    def __init__(
        self, callbacks: QuerystringCallbacks | None = None, strict_parsing: bool = False, max_size: float = float("inf")
    ) -> None:
        super().__init__()
        self.state = QuerystringState.BEFORE_FIELD
        # True once we've skipped the single separator expected between fields.
        self._found_sep = False

        # Use None as the default and create a fresh dict per instance: a
        # shared `{}` default argument would be mutated by set_callback()
        # and leak callbacks across unrelated parser instances.
        self.callbacks = callbacks if callbacks is not None else {}

        # Max-size stuff
        if not isinstance(max_size, Number) or max_size < 1:
            raise ValueError("max_size must be a positive number, not %r" % max_size)
        self.max_size: int | float = max_size
        self._current_size = 0

        # Should parsing be strict?
        self.strict_parsing = strict_parsing

    def write(self, data: bytes) -> int:
        """Write some data to the parser, which will perform size verification,
        parse into either a field name or value, and then pass the
        corresponding data to the underlying callback. If an error is
        encountered while parsing, a QuerystringParseError will be raised. The
        "offset" attribute of the raised exception will be set to the offset in
        the input data chunk (NOT the overall stream) that caused the error.

        Args:
            data: The data to write to the parser.

        Returns:
            The number of bytes written.
        """
        # Handle sizing.
        data_len = len(data)
        if (self._current_size + data_len) > self.max_size:
            # We truncate the length of data that we are to process.
            new_size = int(self.max_size - self._current_size)
            self.logger.warning(
                "Current size is %d (max %d), so truncating data length from %d to %d",
                self._current_size,
                self.max_size,
                data_len,
                new_size,
            )
            data_len = new_size

        bytes_processed = 0
        try:
            bytes_processed = self._internal_write(data, data_len)
        finally:
            # Account for consumed bytes even if parsing raised part-way.
            self._current_size += bytes_processed

        return bytes_processed

    def _internal_write(self, data: bytes, length: int) -> int:
        """Run the state machine over data[:length], emitting callbacks.

        Raises:
            QuerystringParseError: On malformed input when strict_parsing is on.
        """
        state = self.state
        strict_parsing = self.strict_parsing
        found_sep = self._found_sep

        i = 0
        while i < length:
            ch = data[i]

            # Depending on our state...
            if state == QuerystringState.BEFORE_FIELD:
                # If the 'found_sep' flag is set, we've already encountered
                # and skipped a single separator. If so, we check our strict
                # parsing flag and decide what to do. Otherwise, we haven't
                # yet reached a separator, and thus, if we do, we need to skip
                # it as it will be the boundary between fields that's supposed
                # to be there.
                if ch == AMPERSAND or ch == SEMICOLON:
                    if found_sep:
                        # If we're parsing strictly, we disallow blank chunks.
                        if strict_parsing:
                            e = QuerystringParseError("Skipping duplicate ampersand/semicolon at %d" % i)
                            e.offset = i
                            raise e
                        else:
                            self.logger.debug("Skipping duplicate ampersand/semicolon at %d", i)
                    else:
                        # This case is when we're skipping the (first)
                        # separator between fields, so we just set our flag
                        # and continue on.
                        found_sep = True
                else:
                    # Emit a field-start event, and go to that state. Also,
                    # reset the "found_sep" flag, for the next time we get to
                    # this state.
                    self.callback("field_start")
                    i -= 1
                    state = QuerystringState.FIELD_NAME
                    found_sep = False

            elif state == QuerystringState.FIELD_NAME:
                # Try and find a separator - we ensure that, if we do, we only
                # look for the equal sign before it.
                sep_pos = data.find(b"&", i)
                if sep_pos == -1:
                    sep_pos = data.find(b";", i)

                # See if we can find an equals sign in the remaining data. If
                # so, we can immediately emit the field name and jump to the
                # data state.
                if sep_pos != -1:
                    equals_pos = data.find(b"=", i, sep_pos)
                else:
                    equals_pos = data.find(b"=", i)

                if equals_pos != -1:
                    # Emit this name.
                    self.callback("field_name", data, i, equals_pos)

                    # Jump i to this position. Note that it will then have 1
                    # added to it below, which means the next iteration of this
                    # loop will inspect the character after the equals sign.
                    i = equals_pos
                    state = QuerystringState.FIELD_DATA
                else:
                    # No equals sign found.
                    if not strict_parsing:
                        # See also comments in the QuerystringState.FIELD_DATA case below.
                        # If we found the separator, we emit the name and just
                        # end - there's no data callback at all (not even with
                        # a blank value).
                        if sep_pos != -1:
                            self.callback("field_name", data, i, sep_pos)
                            self.callback("field_end")

                            i = sep_pos - 1
                            state = QuerystringState.BEFORE_FIELD
                        else:
                            # Otherwise, no separator in this block, so the
                            # rest of this chunk must be a name.
                            self.callback("field_name", data, i, length)
                            i = length

                    else:
                        # We're parsing strictly. If we find a separator,
                        # this is an error - we require an equals sign.
                        if sep_pos != -1:
                            e = QuerystringParseError(
                                "When strict_parsing is True, we require an "
                                "equals sign in all field chunks. Did not "
                                "find one in the chunk that starts at %d" % (i,)
                            )
                            e.offset = i
                            raise e

                        # No separator in the rest of this chunk, so it's just
                        # a field name.
                        self.callback("field_name", data, i, length)
                        i = length

            elif state == QuerystringState.FIELD_DATA:
                # Try finding either an ampersand or a semicolon after this
                # position.
                sep_pos = data.find(b"&", i)
                if sep_pos == -1:
                    sep_pos = data.find(b";", i)

                # If we found it, callback this bit as data and then go back
                # to expecting to find a field.
                if sep_pos != -1:
                    self.callback("field_data", data, i, sep_pos)
                    self.callback("field_end")

                    # Note that we go to the separator, which brings us to the
                    # "before field" state. This allows us to properly emit
                    # "field_start" events only when we actually have data for
                    # a field of some sort.
                    i = sep_pos - 1
                    state = QuerystringState.BEFORE_FIELD

                # Otherwise, emit the rest as data and finish.
                else:
                    self.callback("field_data", data, i, length)
                    i = length

            else:  # pragma: no cover (error case)
                msg = "Reached an unknown state %d at %d" % (state, i)
                self.logger.warning(msg)
                e = QuerystringParseError(msg)
                e.offset = i
                raise e

            i += 1

        self.state = state
        self._found_sep = found_sep
        # NOTE(review): this returns len(data), not `length`, so bytes dropped
        # by write()'s max_size truncation still count toward the caller's
        # tally - confirm this matches the intended max_size semantics.
        return len(data)

    def finalize(self) -> None:
        """Finalize this parser, which signals to that we are finished parsing,
        if we're still in the middle of a field, an on_field_end callback, and
        then the on_end callback.
        """
        # If we're currently in the middle of a field, we finish it.
        if self.state == QuerystringState.FIELD_DATA:
            self.callback("field_end")
        self.callback("end")

    def __repr__(self) -> str:
        return "{}(strict_parsing={!r}, max_size={!r})".format(
            self.__class__.__name__, self.strict_parsing, self.max_size
        )
954
955
class MultipartParser(BaseParser):
    """This class is a streaming multipart/form-data parser.

    | Callback Name      | Parameters      | Description |
    |--------------------|-----------------|-------------|
    | on_part_begin      | None            | Called when a new part of the multipart message is encountered. |
    | on_part_data       | data, start, end| Called when a portion of a part's data is encountered. |
    | on_part_end        | None            | Called when the end of a part is reached. |
    | on_header_begin    | None            | Called when we've found a new header in a part of a multipart message |
    | on_header_field    | data, start, end| Called each time an additional portion of a header is read (i.e. the part of the header that is before the colon; the "Foo" in "Foo: Bar"). |
    | on_header_value    | data, start, end| Called when we get data for a header. |
    | on_header_end      | None            | Called when the current header is finished - i.e. we've reached the newline at the end of the header. |
    | on_headers_finished| None            | Called when all headers are finished, and before the part data starts. |
    | on_end             | None            | Called when the parser is finished parsing all data. |

    Args:
        boundary: The multipart boundary. This is required, and must match what is given in the HTTP request - usually in the Content-Type header.
        callbacks: A dictionary of callbacks. See the documentation for [`BaseParser`][python_multipart.BaseParser].
        max_size: The maximum size of body to parse. Defaults to infinity - i.e. unbounded.
    """  # noqa: E501

    def __init__(
        self, boundary: bytes | str, callbacks: MultipartCallbacks = {}, max_size: float = float("inf")
    ) -> None:
        """Initialize the parser with the request's multipart boundary.

        Note: the ``callbacks`` default is a mutable dict, but it is never
        mutated here, so sharing the default across instances is harmless.
        """
        # Initialize parser state.
        super().__init__()
        self.state = MultipartState.START
        self.index = self.flags = 0

        self.callbacks = callbacks

        if not isinstance(max_size, Number) or max_size < 1:
            raise ValueError("max_size must be a positive number, not %r" % max_size)
        self.max_size = max_size
        self._current_size = 0

        # Setup marks. These are used to track the state of data received.
        self.marks: dict[str, int] = {}

        # Save our boundary.
        if isinstance(boundary, str):  # pragma: no cover
            boundary = boundary.encode("latin-1")
        # The stored boundary includes the leading CRLF and "--" so that a
        # single `find` can match a full boundary line inside part data.
        self.boundary = b"\r\n--" + boundary

    def write(self, data: bytes) -> int:
        """Write some data to the parser, which will perform size verification,
        and then parse the data into the appropriate location (e.g. header,
        data, etc.), and pass this on to the underlying callback. If an error
        is encountered, a MultipartParseError will be raised. The "offset"
        attribute on the raised exception will be set to the offset of the byte
        in the input chunk that caused the error.

        Args:
            data: The data to write to the parser.

        Returns:
            The number of bytes written.
        """
        # Handle sizing.
        data_len = len(data)
        if (self._current_size + data_len) > self.max_size:
            # We truncate the length of data that we are to process.
            new_size = int(self.max_size - self._current_size)
            self.logger.warning(
                "Current size is %d (max %d), so truncating data length from %d to %d",
                self._current_size,
                self.max_size,
                data_len,
                new_size,
            )
            data_len = new_size

        l = 0
        try:
            l = self._internal_write(data, data_len)
        finally:
            # Even if parsing raised, account for the bytes we did consume.
            self._current_size += l

        return l

    def _internal_write(self, data: bytes, length: int) -> int:
        """Run the multipart state machine over ``data[:length]``.

        State (``self.state``), the boundary-match index (``self.index``),
        boundary flags (``self.flags``), and data marks (``self.marks``) are
        persisted across calls so that boundaries and headers split across
        chunks are handled correctly. Raises MultipartParseError on malformed
        input; returns the number of bytes processed.
        """
        # Get values from locals.
        boundary = self.boundary

        # Get our state, flags and index. These are persisted between calls to
        # this function.
        state = self.state
        index = self.index
        flags = self.flags

        # Our index defaults to 0.
        i = 0

        # Set a mark.
        def set_mark(name: str) -> None:
            self.marks[name] = i

        # Remove a mark.
        def delete_mark(name: str, reset: bool = False) -> None:
            self.marks.pop(name, None)

        # Helper function that makes calling a callback with data easier. The
        # 'remaining' parameter will callback from the marked value until the
        # end of the buffer, and reset the mark, instead of deleting it. This
        # is used at the end of the function to call our callbacks with any
        # remaining data in this chunk.
        def data_callback(name: CallbackName, end_i: int, remaining: bool = False) -> None:
            marked_index = self.marks.get(name)
            if marked_index is None:
                return

            # Otherwise, we call it from the mark to the current byte we're
            # processing.
            if end_i <= marked_index:
                # There is no additional data to send.
                pass
            elif marked_index >= 0:
                # We are emitting data from the local buffer.
                self.callback(name, data, marked_index, end_i)
            else:
                # Some of the data comes from a partial boundary match.
                # and requires look-behind.
                # We need to use self.flags (and not flags) because we care about
                # the state when we entered the loop.
                lookbehind_len = -marked_index
                if lookbehind_len <= len(boundary):
                    self.callback(name, boundary, 0, lookbehind_len)
                elif self.flags & FLAG_PART_BOUNDARY:
                    lookback = boundary + b"\r\n"
                    self.callback(name, lookback, 0, lookbehind_len)
                elif self.flags & FLAG_LAST_BOUNDARY:
                    lookback = boundary + b"--\r\n"
                    self.callback(name, lookback, 0, lookbehind_len)
                else:  # pragma: no cover (error case)
                    self.logger.warning("Look-back buffer error")

                if end_i > 0:
                    self.callback(name, data, 0, end_i)
            # If we're getting remaining data, we have got all the data we
            # can be certain is not a boundary, leaving only a partial boundary match.
            if remaining:
                # Negative mark: offset relative to the end of this chunk,
                # so the look-behind branch above fires on the next call.
                self.marks[name] = end_i - length
            else:
                self.marks.pop(name, None)

        # For each byte...
        while i < length:
            c = data[i]

            if state == MultipartState.START:
                # Skip leading newlines
                if c == CR or c == LF:
                    i += 1
                    continue

                # index is used as in index into our boundary. Set to 0.
                index = 0

                # Move to the next state, but decrement i so that we re-process
                # this character.
                state = MultipartState.START_BOUNDARY
                i -= 1

            elif state == MultipartState.START_BOUNDARY:
                # Check to ensure that the last 2 characters in our boundary
                # are CRLF.
                if index == len(boundary) - 2:
                    if c == HYPHEN:
                        # Potential empty message.
                        state = MultipartState.END_BOUNDARY
                    elif c != CR:
                        # Error!
                        msg = "Did not find CR at end of boundary (%d)" % (i,)
                        self.logger.warning(msg)
                        e = MultipartParseError(msg)
                        e.offset = i
                        raise e

                    index += 1

                elif index == len(boundary) - 2 + 1:
                    if c != LF:
                        msg = "Did not find LF at end of boundary (%d)" % (i,)
                        self.logger.warning(msg)
                        e = MultipartParseError(msg)
                        e.offset = i
                        raise e

                    # The index is now used for indexing into our boundary.
                    index = 0

                    # Callback for the start of a part.
                    self.callback("part_begin")

                    # Move to the next character and state.
                    state = MultipartState.HEADER_FIELD_START

                else:
                    # Check to ensure our boundary matches
                    # (offset +2 skips the leading CRLF stored in self.boundary,
                    # which is not present in the very first boundary line).
                    if c != boundary[index + 2]:
                        msg = "Expected boundary character %r, got %r at index %d" % (boundary[index + 2], c, index + 2)
                        self.logger.warning(msg)
                        e = MultipartParseError(msg)
                        e.offset = i
                        raise e

                    # Increment index into boundary and continue.
                    index += 1

            elif state == MultipartState.HEADER_FIELD_START:
                # Mark the start of a header field here, reset the index, and
                # continue parsing our header field.
                index = 0

                # Set a mark of our header field.
                set_mark("header_field")

                # Notify that we're starting a header if the next character is
                # not a CR; a CR at the beginning of the header will cause us
                # to stop parsing headers in the MultipartState.HEADER_FIELD state,
                # below.
                if c != CR:
                    self.callback("header_begin")

                # Move to parsing header fields.
                state = MultipartState.HEADER_FIELD
                i -= 1

            elif state == MultipartState.HEADER_FIELD:
                # If we've reached a CR at the beginning of a header, it means
                # that we've reached the second of 2 newlines, and so there are
                # no more headers to parse.
                if c == CR and index == 0:
                    delete_mark("header_field")
                    state = MultipartState.HEADERS_ALMOST_DONE
                    i += 1
                    continue

                # Increment our index in the header.
                index += 1

                # If we've reached a colon, we're done with this header.
                if c == COLON:
                    # A 0-length header is an error.
                    if index == 1:
                        msg = "Found 0-length header at %d" % (i,)
                        self.logger.warning(msg)
                        e = MultipartParseError(msg)
                        e.offset = i
                        raise e

                    # Call our callback with the header field.
                    data_callback("header_field", i)

                    # Move to parsing the header value.
                    state = MultipartState.HEADER_VALUE_START

                elif c not in TOKEN_CHARS_SET:
                    msg = "Found invalid character %r in header at %d" % (c, i)
                    self.logger.warning(msg)
                    e = MultipartParseError(msg)
                    e.offset = i
                    raise e

            elif state == MultipartState.HEADER_VALUE_START:
                # Skip leading spaces.
                if c == SPACE:
                    i += 1
                    continue

                # Mark the start of the header value.
                set_mark("header_value")

                # Move to the header-value state, reprocessing this character.
                state = MultipartState.HEADER_VALUE
                i -= 1

            elif state == MultipartState.HEADER_VALUE:
                # If we've got a CR, we're nearly done our headers. Otherwise,
                # we do nothing and just move past this character.
                if c == CR:
                    data_callback("header_value", i)
                    self.callback("header_end")
                    state = MultipartState.HEADER_VALUE_ALMOST_DONE

            elif state == MultipartState.HEADER_VALUE_ALMOST_DONE:
                # The last character should be a LF. If not, it's an error.
                if c != LF:
                    msg = "Did not find LF character at end of header " "(found %r)" % (c,)
                    self.logger.warning(msg)
                    e = MultipartParseError(msg)
                    e.offset = i
                    raise e

                # Move back to the start of another header. Note that if that
                # state detects ANOTHER newline, it'll trigger the end of our
                # headers.
                state = MultipartState.HEADER_FIELD_START

            elif state == MultipartState.HEADERS_ALMOST_DONE:
                # We're almost done our headers. This is reached when we parse
                # a CR at the beginning of a header, so our next character
                # should be a LF, or it's an error.
                if c != LF:
                    msg = f"Did not find LF at end of headers (found {c!r})"
                    self.logger.warning(msg)
                    e = MultipartParseError(msg)
                    e.offset = i
                    raise e

                self.callback("headers_finished")
                state = MultipartState.PART_DATA_START

            elif state == MultipartState.PART_DATA_START:
                # Mark the start of our part data.
                set_mark("part_data")

                # Start processing part data, including this character.
                state = MultipartState.PART_DATA
                i -= 1

            elif state == MultipartState.PART_DATA:
                # We're processing our part data right now. During this, we
                # need to efficiently search for our boundary, since any data
                # on any number of lines can be a part of the current data.

                # Save the current value of our index. We use this in case we
                # find part of a boundary, but it doesn't match fully.
                prev_index = index

                # Set up variables.
                boundary_length = len(boundary)
                data_length = length

                # If our index is 0, we're starting a new part, so start our
                # search.
                if index == 0:
                    # The most common case is likely to be that the whole
                    # boundary is present in the buffer.
                    # Calling `find` is much faster than iterating here.
                    i0 = data.find(boundary, i, data_length)
                    if i0 >= 0:
                        # We matched the whole boundary string.
                        index = boundary_length - 1
                        i = i0 + boundary_length - 1
                    else:
                        # No match found for whole string.
                        # There may be a partial boundary at the end of the
                        # data, which the find will not match.
                        # Since the length should to be searched is limited to
                        # the boundary length, just perform a naive search.
                        i = max(i, data_length - boundary_length)

                        # Search forward until we either hit the end of our buffer,
                        # or reach a potential start of the boundary.
                        while i < data_length - 1 and data[i] != boundary[0]:
                            i += 1

                    c = data[i]

                # Now, we have a couple of cases here. If our index is before
                # the end of the boundary...
                if index < boundary_length:
                    # If the character matches...
                    if boundary[index] == c:
                        # The current character matches, so continue!
                        index += 1
                    else:
                        index = 0

                # Our index is equal to the length of our boundary!
                elif index == boundary_length:
                    # First we increment it.
                    index += 1

                    # Now, if we've reached a newline, we need to set this as
                    # the potential end of our boundary.
                    if c == CR:
                        flags |= FLAG_PART_BOUNDARY

                    # Otherwise, if this is a hyphen, we might be at the last
                    # of all boundaries.
                    elif c == HYPHEN:
                        flags |= FLAG_LAST_BOUNDARY

                    # Otherwise, we reset our index, since this isn't either a
                    # newline or a hyphen.
                    else:
                        index = 0

                # Our index is right after the part boundary, which should be
                # a LF.
                elif index == boundary_length + 1:
                    # If we're at a part boundary (i.e. we've seen a CR
                    # character already)...
                    if flags & FLAG_PART_BOUNDARY:
                        # We need a LF character next.
                        if c == LF:
                            # Unset the part boundary flag.
                            flags &= ~FLAG_PART_BOUNDARY

                            # We have identified a boundary, callback for any data before it.
                            data_callback("part_data", i - index)
                            # Callback indicating that we've reached the end of
                            # a part, and are starting a new one.
                            self.callback("part_end")
                            self.callback("part_begin")

                            # Move to parsing new headers.
                            index = 0
                            state = MultipartState.HEADER_FIELD_START
                            i += 1
                            continue

                        # We didn't find an LF character, so no match. Reset
                        # our index and clear our flag.
                        index = 0
                        flags &= ~FLAG_PART_BOUNDARY

                    # Otherwise, if we're at the last boundary (i.e. we've
                    # seen a hyphen already)...
                    elif flags & FLAG_LAST_BOUNDARY:
                        # We need a second hyphen here.
                        if c == HYPHEN:
                            # We have identified a boundary, callback for any data before it.
                            data_callback("part_data", i - index)
                            # Callback to end the current part, and then the
                            # message.
                            self.callback("part_end")
                            self.callback("end")
                            state = MultipartState.END
                        else:
                            # No match, so reset index.
                            index = 0

                # Otherwise, our index is 0. If the previous index is not, it
                # means we reset something, and we need to take the data we
                # thought was part of our boundary and send it along as actual
                # data.
                if index == 0 and prev_index > 0:
                    # Overwrite our previous index.
                    prev_index = 0

                    # Re-consider the current character, since this could be
                    # the start of the boundary itself.
                    i -= 1

            elif state == MultipartState.END_BOUNDARY:
                if index == len(boundary) - 2 + 1:
                    if c != HYPHEN:
                        msg = "Did not find - at end of boundary (%d)" % (i,)
                        self.logger.warning(msg)
                        e = MultipartParseError(msg)
                        e.offset = i
                        raise e
                    index += 1
                    self.callback("end")
                    state = MultipartState.END

            elif state == MultipartState.END:
                # Don't do anything if chunk ends with CRLF.
                if c == CR and i + 1 < length and data[i + 1] == LF:
                    i += 2
                    continue
                # Skip data after the last boundary.
                self.logger.warning("Skipping data after last boundary")
                i = length
                break

            else:  # pragma: no cover (error case)
                # We got into a strange state somehow! Just stop processing.
                msg = "Reached an unknown state %d at %d" % (state, i)
                self.logger.warning(msg)
                e = MultipartParseError(msg)
                e.offset = i
                raise e

            # Move to the next byte.
            i += 1

        # We call our callbacks with any remaining data. Note that we pass
        # the 'remaining' flag, which sets the mark back to 0 instead of
        # deleting it, if it's found. This is because, if the mark is found
        # at this point, we assume that there's data for one of these things
        # that has been parsed, but not yet emitted. And, as such, it implies
        # that we haven't yet reached the end of this 'thing'. So, by setting
        # the mark to 0, we cause any data callbacks that take place in future
        # calls to this function to start from the beginning of that buffer.
        data_callback("header_field", length, True)
        data_callback("header_value", length, True)
        data_callback("part_data", length - index, True)

        # Save values to locals.
        self.state = state
        self.index = index
        self.flags = flags

        # Return our data length to indicate no errors, and that we processed
        # all of it.
        return length

    def finalize(self) -> None:
        """Finalize this parser, which signals to that we are finished parsing.

        Note: It does not currently, but in the future, it will verify that we
        are in the final state of the parser (i.e. the end of the multipart
        message is well-formed), and, if not, throw an error.
        """
        # TODO: verify that we're in the state MultipartState.END, otherwise throw an
        # error or otherwise state that we're not finished parsing.
        pass

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(boundary={self.boundary!r})"
1470
1471
class FormParser:
    """This class is the all-in-one form parser. Given all the information
    necessary to parse a form, it will instantiate the correct parser, create
    the proper :class:`Field` and :class:`File` classes to store the data that
    is parsed, and call the two given callbacks with each field and file as
    they become available.

    Args:
        content_type: The Content-Type of the incoming request. This is used to select the appropriate parser.
        on_field: The callback to call when a field has been parsed and is ready for usage. See above for parameters.
        on_file: The callback to call when a file has been parsed and is ready for usage. See above for parameters.
        on_end: An optional callback to call when all fields and files in a request has been parsed. Can be None.
        boundary: If the request is a multipart/form-data request, this should be the boundary of the request, as given
            in the Content-Type header, as a bytestring.
        file_name: If the request is of type application/octet-stream, then the body of the request will not contain any
            information about the uploaded file. In such cases, you can provide the file name of the uploaded file
            manually.
        FileClass: The class to use for uploaded files. Defaults to :class:`File`, but you can provide your own class
            if you wish to customize behaviour. The class will be instantiated as FileClass(file_name, field_name), and
            it must provide the following functions::
                - file_instance.write(data)
                - file_instance.finalize()
                - file_instance.close()
        FieldClass: The class to use for uploaded fields. Defaults to :class:`Field`, but you can provide your own
            class if you wish to customize behaviour. The class will be instantiated as FieldClass(field_name), and it
            must provide the following functions::
                - field_instance.write(data)
                - field_instance.finalize()
                - field_instance.close()
                - field_instance.set_none()
        config: Configuration to use for this FormParser. The default values are taken from the DEFAULT_CONFIG value,
            and then any keys present in this dictionary will overwrite the default values.

    Raises:
        FormParserError: If the Content-Type is unknown, or a multipart request
            is missing its boundary.
    """

    #: This is the default configuration for our form parser.
    #: Note: all file sizes should be in bytes.
    DEFAULT_CONFIG: FormParserConfig = {
        "MAX_BODY_SIZE": float("inf"),
        "MAX_MEMORY_FILE_SIZE": 1 * 1024 * 1024,
        "UPLOAD_DIR": None,
        "UPLOAD_KEEP_FILENAME": False,
        "UPLOAD_KEEP_EXTENSIONS": False,
        # Error on invalid Content-Transfer-Encoding?
        "UPLOAD_ERROR_ON_BAD_CTE": False,
    }

    def __init__(
        self,
        content_type: str,
        on_field: OnFieldCallback | None,
        on_file: OnFileCallback | None,
        on_end: Callable[[], None] | None = None,
        boundary: bytes | str | None = None,
        file_name: bytes | None = None,
        FileClass: type[FileProtocol] = File,
        FieldClass: type[FieldProtocol] = Field,
        config: dict[Any, Any] = {},
    ) -> None:
        self.logger = logging.getLogger(__name__)

        # Save variables.
        self.content_type = content_type
        self.boundary = boundary
        self.bytes_received = 0
        self.parser = None

        # Save callbacks.
        self.on_field = on_field
        self.on_file = on_file
        self.on_end = on_end

        # Save classes. Store the constructor arguments (not the module-level
        # defaults), so custom classes passed by the caller are reflected on
        # the instance as well as used by the closures below.
        self.FileClass = FileClass
        self.FieldClass = FieldClass

        # Set configuration options. Copy the defaults first so per-instance
        # overrides never mutate DEFAULT_CONFIG.
        self.config: FormParserConfig = self.DEFAULT_CONFIG.copy()
        self.config.update(config)  # type: ignore[typeddict-item]

        parser: OctetStreamParser | MultipartParser | QuerystringParser | None = None

        # Depending on the Content-Type, we instantiate the correct parser.
        if content_type == "application/octet-stream":
            file: FileProtocol = None  # type: ignore

            def on_start() -> None:
                # Create the file object lazily, when data actually starts.
                nonlocal file
                file = FileClass(file_name, None, config=cast("FileConfig", self.config))

            def on_data(data: bytes, start: int, end: int) -> None:
                nonlocal file
                file.write(data[start:end])

            def _on_end() -> None:
                nonlocal file
                # Finalize the file itself.
                file.finalize()

                # Call our callback.
                if on_file:
                    on_file(file)

                # Call the on-end callback.
                if self.on_end is not None:
                    self.on_end()

            # Instantiate an octet-stream parser
            parser = OctetStreamParser(
                callbacks={"on_start": on_start, "on_data": on_data, "on_end": _on_end},
                max_size=self.config["MAX_BODY_SIZE"],
            )

        elif content_type == "application/x-www-form-urlencoded" or content_type == "application/x-url-encoded":
            name_buffer: list[bytes] = []

            f: FieldProtocol | None = None

            def on_field_start() -> None:
                pass

            def on_field_name(data: bytes, start: int, end: int) -> None:
                name_buffer.append(data[start:end])

            def on_field_data(data: bytes, start: int, end: int) -> None:
                nonlocal f
                if f is None:
                    f = FieldClass(b"".join(name_buffer))
                    del name_buffer[:]
                f.write(data[start:end])

            def on_field_end() -> None:
                nonlocal f
                # Finalize and call callback.
                if f is None:
                    # If we get here, it's because there was no field data.
                    # We create a field, set it to None, and then continue.
                    f = FieldClass(b"".join(name_buffer))
                    del name_buffer[:]
                    f.set_none()

                f.finalize()
                if on_field:
                    on_field(f)
                f = None

            def _on_end() -> None:
                if self.on_end is not None:
                    self.on_end()

            # Instantiate parser.
            parser = QuerystringParser(
                callbacks={
                    "on_field_start": on_field_start,
                    "on_field_name": on_field_name,
                    "on_field_data": on_field_data,
                    "on_field_end": on_field_end,
                    "on_end": _on_end,
                },
                max_size=self.config["MAX_BODY_SIZE"],
            )

        elif content_type == "multipart/form-data":
            if boundary is None:
                self.logger.error("No boundary given")
                raise FormParserError("No boundary given")

            header_name: list[bytes] = []
            header_value: list[bytes] = []
            headers: dict[bytes, bytes] = {}

            f_multi: FileProtocol | FieldProtocol | None = None
            writer = None
            is_file = False

            def on_part_begin() -> None:
                # Reset headers in case this isn't the first part.
                nonlocal headers
                headers = {}

            def on_part_data(data: bytes, start: int, end: int) -> None:
                nonlocal writer
                assert writer is not None
                writer.write(data[start:end])
                # TODO: check for error here.

            def on_part_end() -> None:
                nonlocal f_multi, is_file
                assert f_multi is not None
                f_multi.finalize()
                if is_file:
                    if on_file:
                        on_file(f_multi)
                else:
                    if on_field:
                        on_field(cast("FieldProtocol", f_multi))

            def on_header_field(data: bytes, start: int, end: int) -> None:
                header_name.append(data[start:end])

            def on_header_value(data: bytes, start: int, end: int) -> None:
                header_value.append(data[start:end])

            def on_header_end() -> None:
                headers[b"".join(header_name)] = b"".join(header_value)
                del header_name[:]
                del header_value[:]

            def on_headers_finished() -> None:
                nonlocal is_file, f_multi, writer
                # Reset the 'is file' flag.
                is_file = False

                # Parse the content-disposition header.
                # TODO: handle mixed case
                content_disp = headers.get(b"Content-Disposition")
                disp, options = parse_options_header(content_disp)

                # Get the field and filename.
                field_name = options.get(b"name")
                file_name = options.get(b"filename")
                # TODO: check for errors

                # Create the proper class. A part with a filename is treated
                # as a file upload; otherwise it's a plain form field.
                if file_name is None:
                    f_multi = FieldClass(field_name)
                else:
                    f_multi = FileClass(file_name, field_name, config=cast("FileConfig", self.config))
                    is_file = True

                # Parse the given Content-Transfer-Encoding to determine what
                # we need to do with the incoming data.
                # TODO: check that we properly handle 8bit / 7bit encoding.
                transfer_encoding = headers.get(b"Content-Transfer-Encoding", b"7bit")

                if transfer_encoding in (b"binary", b"8bit", b"7bit"):
                    writer = f_multi

                elif transfer_encoding == b"base64":
                    writer = Base64Decoder(f_multi)

                elif transfer_encoding == b"quoted-printable":
                    writer = QuotedPrintableDecoder(f_multi)

                else:
                    self.logger.warning("Unknown Content-Transfer-Encoding: %r", transfer_encoding)
                    if self.config["UPLOAD_ERROR_ON_BAD_CTE"]:
                        raise FormParserError('Unknown Content-Transfer-Encoding "{!r}"'.format(transfer_encoding))
                    else:
                        # If we aren't erroring, then we just treat this as an
                        # unencoded Content-Transfer-Encoding.
                        writer = f_multi

            def _on_end() -> None:
                nonlocal writer
                if writer is not None:
                    writer.finalize()
                if self.on_end is not None:
                    self.on_end()

            # Instantiate a multipart parser.
            parser = MultipartParser(
                boundary,
                callbacks={
                    "on_part_begin": on_part_begin,
                    "on_part_data": on_part_data,
                    "on_part_end": on_part_end,
                    "on_header_field": on_header_field,
                    "on_header_value": on_header_value,
                    "on_header_end": on_header_end,
                    "on_headers_finished": on_headers_finished,
                    "on_end": _on_end,
                },
                max_size=self.config["MAX_BODY_SIZE"],
            )

        else:
            self.logger.warning("Unknown Content-Type: %r", content_type)
            raise FormParserError("Unknown Content-Type: {}".format(content_type))

        self.parser = parser

    def write(self, data: bytes) -> int:
        """Write some data. The parser will forward this to the appropriate
        underlying parser.

        Args:
            data: The data to write.

        Returns:
            The number of bytes processed.
        """
        self.bytes_received += len(data)
        # TODO: check the parser's return value for errors?
        assert self.parser is not None
        return self.parser.write(data)

    def finalize(self) -> None:
        """Finalize the parser."""
        if self.parser is not None and hasattr(self.parser, "finalize"):
            self.parser.finalize()

    def close(self) -> None:
        """Close the parser."""
        if self.parser is not None and hasattr(self.parser, "close"):
            self.parser.close()

    def __repr__(self) -> str:
        return "{}(content_type={!r}, parser={!r})".format(self.__class__.__name__, self.content_type, self.parser)
1780
1781
def create_form_parser(
    headers: dict[str, bytes],
    on_field: OnFieldCallback | None,
    on_file: OnFileCallback | None,
    trust_x_headers: bool = False,
    config: dict[Any, Any] = {},
) -> FormParser:
    """This function is a helper function to aid in creating a FormParser
    instances. Given a dictionary-like headers object, it will determine
    the correct information needed, instantiate a FormParser with the
    appropriate values and given callbacks, and then return the corresponding
    parser.

    Args:
        headers: A dictionary-like object of HTTP headers. The only required header is Content-Type.
        on_field: Callback to call with each parsed field.
        on_file: Callback to call with each parsed file.
        trust_x_headers: Whether or not to trust information received from certain X-Headers - for example, the file
            name from X-File-Name. When False (the default), X-File-Name is ignored.
        config: Configuration variables to pass to the FormParser.

    Raises:
        ValueError: If no Content-Type header is present.
    """
    content_type: str | bytes | None = headers.get("Content-Type")
    if content_type is None:
        logging.getLogger(__name__).warning("No Content-Type header given")
        raise ValueError("No Content-Type header given!")

    # Boundaries are optional (the FormParser will raise if one is needed
    # but not given).
    content_type, params = parse_options_header(content_type)
    boundary = params.get(b"boundary")

    # We need content_type to be a string, not a bytes object.
    content_type = content_type.decode("latin-1")

    # File names are optional. X-File-Name is client-supplied (untrusted)
    # input, so only honor it when the caller explicitly opted in.
    file_name = headers.get("X-File-Name") if trust_x_headers else None

    # Instantiate a form parser.
    form_parser = FormParser(content_type, on_field, on_file, boundary=boundary, file_name=file_name, config=config)

    # Return our parser.
    return form_parser
1824
1825
def parse_form(
    headers: dict[str, bytes],
    input_stream: SupportsRead,
    on_field: OnFieldCallback | None,
    on_file: OnFileCallback | None,
    chunk_size: int = 1048576,
) -> None:
    """This function is useful if you just want to parse a request body,
    without too much work. Pass it a dictionary-like object of the request's
    headers, and a file-like object for the input stream, along with two
    callbacks that will get called whenever a field or file is parsed.

    Args:
        headers: A dictionary-like object of HTTP headers. The only required header is Content-Type.
        input_stream: A file-like object that represents the request body. The read() method must return bytestrings.
        on_field: Callback to call with each parsed field.
        on_file: Callback to call with each parsed file.
        chunk_size: The maximum size to read from the input stream and write to the parser at one time.
            Defaults to 1 MiB.
    """
    # Build the appropriate parser for this request's Content-Type.
    form_parser = create_form_parser(headers, on_field, on_file)

    # Determine how much we are allowed to read: the declared Content-Length
    # if present, otherwise unbounded (read until EOF).
    raw_length = headers.get("Content-Length")
    limit: int | float = float("inf") if raw_length is None else int(raw_length)

    consumed = 0
    while True:
        # Never request more than the remaining declared body length.
        want = int(min(limit - consumed, chunk_size))
        chunk = input_stream.read(want)

        # Feed the parser and track how much we've pulled from the stream.
        form_parser.write(chunk)
        consumed += len(chunk)

        # A short read means the stream is exhausted; reaching the declared
        # Content-Length means the body is complete.
        if len(chunk) != want or consumed == limit:
            break

    # Signal end-of-input so any in-progress field/file is flushed.
    form_parser.finalize()