1from __future__ import annotations
2
3import io
4import logging
5import os
6import shutil
7import sys
8import tempfile
9from email.message import Message
10from enum import IntEnum
11from io import BytesIO
12from numbers import Number
13from typing import TYPE_CHECKING, Any
14
15from .decoders import Base64Decoder, QuotedPrintableDecoder
16from .exceptions import FileError, FormParserError, MultipartParseError, QuerystringParseError
17
if TYPE_CHECKING:  # pragma: no cover
    from typing import Callable, Protocol, TypedDict

    # Callbacks accepted by QuerystringParser. Notification callbacks take no
    # arguments; data callbacks receive (data, start, end), where the slice
    # data[start:end] is the region of interest.
    class QuerystringCallbacks(TypedDict, total=False):
        on_field_start: Callable[[], None]
        on_field_name: Callable[[bytes, int, int], None]
        on_field_data: Callable[[bytes, int, int], None]
        on_field_end: Callable[[], None]
        on_end: Callable[[], None]

    # Callbacks accepted by OctetStreamParser.
    class OctetStreamCallbacks(TypedDict, total=False):
        on_start: Callable[[], None]
        on_data: Callable[[bytes, int, int], None]
        on_end: Callable[[], None]

    # Callbacks accepted by MultipartParser.
    class MultipartCallbacks(TypedDict, total=False):
        on_part_begin: Callable[[], None]
        on_part_data: Callable[[bytes, int, int], None]
        on_part_end: Callable[[], None]
        on_header_begin: Callable[[], None]
        on_header_field: Callable[[bytes, int, int], None]
        on_header_value: Callable[[bytes, int, int], None]
        on_header_end: Callable[[], None]
        on_headers_finished: Callable[[], None]
        on_end: Callable[[], None]

    # Configuration for the form parser (all keys required here, unlike
    # FileConfig below).
    class FormParserConfig(TypedDict):
        UPLOAD_DIR: str | None
        UPLOAD_KEEP_FILENAME: bool
        UPLOAD_KEEP_EXTENSIONS: bool
        UPLOAD_ERROR_ON_BAD_CTE: bool
        MAX_MEMORY_FILE_SIZE: int
        MAX_BODY_SIZE: float

    # Configuration accepted by File; every key is optional.
    class FileConfig(TypedDict, total=False):
        UPLOAD_DIR: str | bytes | None
        UPLOAD_DELETE_TMP: bool
        UPLOAD_KEEP_FILENAME: bool
        UPLOAD_KEEP_EXTENSIONS: bool
        MAX_MEMORY_FILE_SIZE: int

    # Structural type shared by Field and File: anything the form parser can
    # stream bytes into and then finalize/close.
    class _FormProtocol(Protocol):
        def write(self, data: bytes) -> int:
            ...

        def finalize(self) -> None:
            ...

        def close(self) -> None:
            ...

    class FieldProtocol(_FormProtocol, Protocol):
        def __init__(self, name: bytes) -> None:
            ...

        def set_none(self) -> None:
            ...

    class FileProtocol(_FormProtocol, Protocol):
        def __init__(self, file_name: bytes | None, field_name: bytes | None, config: FileConfig) -> None:
            ...

    OnFieldCallback = Callable[[FieldProtocol], None]
    OnFileCallback = Callable[[FileProtocol], None]
82
83
# Unique sentinel marking a Field's cached value as "not yet computed",
# distinct from a legitimate cached value of None (see Field.set_none()).
_missing = object()
86
87
class QuerystringState(IntEnum):
    """Querystring parser states.

    These are used to keep track of the state of the parser, and are used to determine
    what to do when new data is encountered.
    """

    BEFORE_FIELD = 0  # Between fields: skipping the "&"/";" separator(s).
    FIELD_NAME = 1  # Reading a field name, up to "=" or a separator.
    FIELD_DATA = 2  # Reading a field value, up to the next separator.
98
99
class MultipartState(IntEnum):
    """Multipart parser states.

    These are used to keep track of the state of the parser, and are used to determine
    what to do when new data is encountered.
    """

    START = 0  # Initial state; nothing parsed yet.
    START_BOUNDARY = 1  # Reading the message's first boundary line.
    HEADER_FIELD_START = 2  # At the start of a header line within a part.
    HEADER_FIELD = 3  # Reading a header name (before the colon).
    HEADER_VALUE_START = 4  # Between the colon and the header value.
    HEADER_VALUE = 5  # Reading a header value (after the colon).
    HEADER_VALUE_ALMOST_DONE = 6  # Header value ended; expecting end-of-line.
    HEADERS_ALMOST_DONE = 7  # Headers ended; expecting the blank line before data.
    PART_DATA_START = 8  # About to read a part's payload.
    PART_DATA = 9  # Reading a part's payload.
    PART_DATA_END = 10  # Finished a part's payload.
    END = 11  # Final boundary seen; parsing complete.
119
120
# Flags for the multipart parser.
FLAG_PART_BOUNDARY = 1
FLAG_LAST_BOUNDARY = 2

# Get constants. Since iterating over a str on Python 2 gives you a 1-length
# string, but iterating over a bytes object on Python 3 gives you an integer,
# we need to save these constants.
# Each of these is the integer value of a single ASCII byte, for comparison
# against `data[i]` inside the parser loops.
CR = b"\r"[0]
LF = b"\n"[0]
COLON = b":"[0]
SPACE = b" "[0]
HYPHEN = b"-"[0]
AMPERSAND = b"&"[0]
SEMICOLON = b";"[0]
LOWER_A = b"a"[0]
LOWER_Z = b"z"[0]
NULL = b"\x00"[0]

# Mask for ASCII characters that can be http tokens.
# Per RFC7230 - 3.2.6, this is all alpha-numeric characters
# and these: !#$%&'*+-.^_`|~
TOKEN_CHARS_SET = frozenset(
    b"ABCDEFGHIJKLMNOPQRSTUVWXYZ"
    b"abcdefghijklmnopqrstuvwxyz"
    b"0123456789"
    b"!#$%&'*+-.^_`|~")
147
148
def ord_char(c: int) -> int:
    """Return *c* unchanged.

    Legacy shim from the Python 2/3 transition: indexing a bytes object on
    Python 3 already yields an integer, so no conversion is needed.
    """
    return c
151
152
def join_bytes(b: bytes) -> bytes:
    """Return *b* as a bytes object.

    Legacy helper from the Python 2/3 compatibility era. The old
    ``bytes(list(b))`` round-tripped through a throwaway list of integers;
    calling ``bytes()`` on the buffer directly is equivalent (it also accepts
    any iterable of ints) and avoids the intermediate list.
    """
    return bytes(b)
155
156
def parse_options_header(value: str | bytes) -> tuple[bytes, dict[bytes, bytes]]:
    """Parses a Content-Type header into a value in the following format: (content_type, {parameters})."""
    # Uses email.message.Message to parse the header as described in PEP 594.
    # Ref: https://peps.python.org/pep-0594/#cgi
    if not value:
        return (b"", {})

    # Bytes input is assumed to conform to WSGI, i.e. latin-1 encoded.
    if isinstance(value, bytes):  # pragma: no cover
        value = value.decode("latin-1")

    # For types
    assert isinstance(value, str), "Value should be a string by now"

    # Fast path: no parameters at all, so return the bare content type.
    if ";" not in value:
        return (value.lower().strip().encode("latin-1"), {})

    # Let email.message.Message split the header into the content type plus
    # its parameter list.
    message = Message()
    message["content-type"] = value
    parsed = message.get_params()
    # If there were no parameters, this would have already returned above
    assert parsed, "At least the content type value should be present"

    content_type = parsed[0][0].encode("latin-1")
    options: dict[bytes, bytes] = {}
    for key, param_value in parsed[1:]:
        # get_params() returns a (charset, language, value) 3-tuple for
        # RFC 2231 encoded parameters; the last element is the value.
        # See: https://docs.python.org/3/library/email.compat32-message.html
        if isinstance(param_value, tuple):
            param_value = param_value[-1]
        # Work around an IE6 bug that sends the full file path (e.g.
        # "C:\dir\file.txt" or a UNC "\\host\share" path) instead of just
        # the filename: keep only the final path component.
        if key == "filename" and (param_value[1:3] == ":\\" or param_value[:2] == "\\\\"):
            param_value = param_value.split("\\")[-1]
        options[key.encode("latin-1")] = param_value.encode("latin-1")
    return content_type, options
198
199
class Field:
    """A single (parsed) form field: a name together with a value.

    The name a :class:`Field` is instantiated with is the same name that
    would appear in the following HTML::

        <input name="name_goes_here" type="text"/>

    Two callback-style methods are defined: :meth:`on_data`, invoked whenever
    data is written to the field, and :meth:`on_end`, invoked when the field
    is finalized.

    Args:
        name: The name of the form field.
    """

    def __init__(self, name: bytes) -> None:
        self._name = name
        self._value: list[bytes] = []

        # Joined form of _value, cached for speed. The module-level _missing
        # sentinel marks the cache as stale/unset.
        self._cache = _missing

    @classmethod
    def from_value(cls, name: bytes, value: bytes | None) -> Field:
        """Build an already-finalized :class:`Field` from a name and value.

        Args:
            name: the name of the form field.
            value: the value of the form field - either a bytestring or None.

        Returns:
            A new instance of a [`Field`][multipart.Field].
        """

        field = cls(name)
        if value is None:
            field.set_none()
        else:
            field.write(value)
        field.finalize()
        return field

    def write(self, data: bytes) -> int:
        """Write some data into the form field.

        Args:
            data: The data to write to the field.

        Returns:
            The number of bytes written.
        """
        return self.on_data(data)

    def on_data(self, data: bytes) -> int:
        """Callback invoked whenever data is written to the Field.

        Args:
            data: The data to write to the field.

        Returns:
            The number of bytes written.
        """
        self._value.append(data)
        # New data invalidates any previously-joined value.
        self._cache = _missing
        return len(data)

    def on_end(self) -> None:
        """Callback invoked when the Field is finalized."""
        if self._cache is _missing:
            self._cache = b"".join(self._value)

    def finalize(self) -> None:
        """Finalize the form field."""
        self.on_end()

    def close(self) -> None:
        """Close the Field object. This will free any underlying cache."""
        # Preserve the joined value before dropping the chunk list.
        if self._cache is _missing:
            self._cache = b"".join(self._value)

        del self._value

    def set_none(self) -> None:
        """Mark this field as having a value of None.

        A querystring such as "foo&bar=&baz=asdf" produces a field named
        "foo" with value None, one named "bar" with value "", and one named
        "baz" with value "asdf". Since the write() interface cannot express
        None, this method sets the field value to None directly.
        """
        self._cache = None

    @property
    def field_name(self) -> bytes:
        """The name of this field."""
        return self._name

    @property
    def value(self) -> bytes | None:
        """The value of this form field (None if set_none() was called)."""
        if self._cache is _missing:
            self._cache = b"".join(self._value)

        return self._cache

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, Field):
            return NotImplemented
        return self.field_name == other.field_name and self.value == other.value

    def __repr__(self) -> str:
        val = self.value
        if val is not None and len(val) > 97:
            # Truncate long values: repr the first 97 bytes, then splice
            # three dots in before the closing quote.
            shown = repr(val[:97])[:-1] + "...'"
        else:
            shown = repr(val)

        return f"{self.__class__.__name__}(field_name={self.field_name!r}, value={shown})"
325
326
class File:
    """This class represents an uploaded file. It handles writing file data to
    either an in-memory file or a temporary file on-disk, if the optional
    threshold is passed.

    There are some options that can be passed to the File to change behavior
    of the class. Valid options are as follows:

    | Name | Type | Default | Description |
    |-----------------------|-------|---------|-------------|
    | UPLOAD_DIR | `str` | None | The directory to store uploaded files in. If this is None, a temporary file will be created in the system's standard location. |
    | UPLOAD_DELETE_TMP | `bool`| True | Delete automatically created TMP file |
    | UPLOAD_KEEP_FILENAME | `bool`| False | Whether or not to keep the filename of the uploaded file. If True, then the filename will be converted to a safe representation (e.g. by removing any invalid path segments), and then saved with the same name). Otherwise, a temporary name will be used. |
    | UPLOAD_KEEP_EXTENSIONS| `bool`| False | Whether or not to keep the uploaded file's extension. If False, the file will be saved with the default temporary extension (usually ".tmp"). Otherwise, the file's extension will be maintained. Note that this will properly combine with the UPLOAD_KEEP_FILENAME setting. |
    | MAX_MEMORY_FILE_SIZE | `int` | 1 MiB | The maximum number of bytes of a File to keep in memory. By default, the contents of a File are kept into memory until a certain limit is reached, after which the contents of the File are written to a temporary file. This behavior can be disabled by setting this value to an appropriately large value (or, for example, infinity, such as `float('inf')`. |

    Args:
        file_name: The name of the file that this [`File`][multipart.File] represents.
        field_name: The name of the form field that this file was uploaded with. This can be None, if, for example,
            the file was uploaded with Content-Type application/octet-stream.
        config: The configuration for this File. See above for valid configuration keys and their corresponding values.
    """  # noqa: E501

    def __init__(self, file_name: bytes | None, field_name: bytes | None = None, config: FileConfig = {}) -> None:
        # NOTE: the mutable default for `config` is safe here because the
        # dict is only ever read (via .get()), never mutated.
        # Save configuration, set other variables default.
        self.logger = logging.getLogger(__name__)
        self._config = config
        self._in_memory = True
        self._bytes_written = 0
        self._fileobj = BytesIO()

        # Save the provided field/file name.
        self._field_name = field_name
        self._file_name = file_name

        # Our actual file name is None by default, since, depending on our
        # config, we may not actually use the provided name.
        self._actual_file_name: bytes | None = None

        # Split the extension from the filename.
        if file_name is not None:
            base, ext = os.path.splitext(file_name)
            self._file_base = base
            self._ext = ext

    @property
    def field_name(self) -> bytes | None:
        """The form field associated with this file. May be None if there isn't
        one, for example when we have an application/octet-stream upload.
        """
        return self._field_name

    @property
    def file_name(self) -> bytes | None:
        """The file name given in the upload request."""
        return self._file_name

    @property
    def actual_file_name(self) -> bytes | None:
        """The file name that this file is saved as. Will be None if it's not
        currently saved on disk.
        """
        return self._actual_file_name

    @property
    def file_object(self):
        """The file object that we're currently writing to. Note that this
        will either be an instance of a :class:`io.BytesIO`, or a regular file
        object.
        """
        return self._fileobj

    @property
    def size(self) -> int:
        """The total size of this file, counted as the number of bytes that
        currently have been written to the file.
        """
        return self._bytes_written

    @property
    def in_memory(self) -> bool:
        """A boolean representing whether or not this file object is currently
        stored in-memory or on-disk.
        """
        return self._in_memory

    def flush_to_disk(self) -> None:
        """If the file is already on-disk, do nothing. Otherwise, copy from
        the in-memory buffer to a disk file, and then reassign our internal
        file object to this new disk file.

        Note that if you attempt to flush a file that is already on-disk, a
        warning will be logged to this module's logger.
        """
        if not self._in_memory:
            self.logger.warning("Trying to flush to disk when we're not in memory")
            return

        # Go back to the start of our file.
        self._fileobj.seek(0)

        # Open a new file.
        new_file = self._get_disk_file()

        # Copy the file objects.
        shutil.copyfileobj(self._fileobj, new_file)

        # Seek to the new position in our new file.
        new_file.seek(self._bytes_written)

        # Reassign the fileobject.
        old_fileobj = self._fileobj
        self._fileobj = new_file

        # We're no longer in memory.
        self._in_memory = False

        # Close the old file object.
        old_fileobj.close()

    def _get_disk_file(self) -> io.BufferedRandom | tempfile._TemporaryFileWrapper[bytes]:  # type: ignore[reportPrivateUsage]
        """This function is responsible for getting a file object on-disk for us.

        Raises:
            FileError: If the on-disk file could not be opened/created.
        """
        self.logger.info("Opening a file on disk")

        file_dir = self._config.get("UPLOAD_DIR")
        keep_filename = self._config.get("UPLOAD_KEEP_FILENAME", False)
        keep_extensions = self._config.get("UPLOAD_KEEP_EXTENSIONS", False)
        delete_tmp = self._config.get("UPLOAD_DELETE_TMP", True)

        # If we have a directory and are to keep the filename...
        if file_dir is not None and keep_filename:
            self.logger.info("Saving with filename in: %r", file_dir)

            # Build our filename.
            # TODO: what happens if we don't have a filename?
            fname = self._file_base + self._ext if keep_extensions else self._file_base

            # os.path.join() refuses to mix str and bytes components, and
            # `file_dir` (from config) may be either, while `fname` is bytes.
            # Normalize both to str via the filesystem encoding before joining.
            path = os.path.join(os.fsdecode(file_dir), os.fsdecode(fname))
            try:
                self.logger.info("Opening file: %r", path)
                tmp_file = open(path, "w+b")
            except OSError:
                self.logger.exception("Error opening temporary file")
                raise FileError("Error opening temporary file: %r" % path)
        else:
            # Build options array.
            # Note that on Python 3, tempfile doesn't support byte names. We
            # encode our paths using the default filesystem encoding.
            suffix = self._ext.decode(sys.getfilesystemencoding()) if keep_extensions else None

            if file_dir is None:
                dir = None
            elif isinstance(file_dir, bytes):
                dir = file_dir.decode(sys.getfilesystemencoding())
            else:
                dir = file_dir

            # Create a temporary (named) file with the appropriate settings.
            self.logger.info(
                "Creating a temporary file with options: %r", {"suffix": suffix, "delete": delete_tmp, "dir": dir}
            )
            try:
                tmp_file = tempfile.NamedTemporaryFile(suffix=suffix, delete=delete_tmp, dir=dir)
            except OSError:
                self.logger.exception("Error creating named temporary file")
                raise FileError("Error creating named temporary file")

            fname = tmp_file.name

        # Encode filename as bytes.
        if isinstance(fname, str):
            fname = fname.encode(sys.getfilesystemencoding())

        self._actual_file_name = fname
        return tmp_file

    def write(self, data: bytes) -> int:
        """Write some data to the File.

        :param data: a bytestring
        """
        return self.on_data(data)

    def on_data(self, data: bytes) -> int:
        """This method is a callback that will be called whenever data is
        written to the File.

        Args:
            data: The data to write to the file.

        Returns:
            The number of bytes written.
        """
        pos = self._fileobj.tell()
        bwritten = self._fileobj.write(data)
        # true file objects write returns None
        if bwritten is None:
            bwritten = self._fileobj.tell() - pos

        # If the bytes written isn't the same as the length, just return.
        if bwritten != len(data):
            self.logger.warning("bwritten != len(data) (%d != %d)", bwritten, len(data))
            return bwritten

        # Keep track of how many bytes we've written.
        self._bytes_written += bwritten

        # If we're in-memory and are over our limit, we create a file.
        max_memory_file_size = self._config.get("MAX_MEMORY_FILE_SIZE")
        if self._in_memory and max_memory_file_size is not None and (self._bytes_written > max_memory_file_size):
            self.logger.info("Flushing to disk")
            self.flush_to_disk()

        # Return the number of bytes written.
        return bwritten

    def on_end(self) -> None:
        """This method is called whenever the Field is finalized."""
        # Flush the underlying file object
        self._fileobj.flush()

    def finalize(self) -> None:
        """Finalize the form file. This will not close the underlying file,
        but simply signal that we are finished writing to the File.
        """
        self.on_end()

    def close(self) -> None:
        """Close the File object. This will actually close the underlying
        file object (whether it's a :class:`io.BytesIO` or an actual file
        object).
        """
        self._fileobj.close()

    def __repr__(self) -> str:
        return "{}(file_name={!r}, field_name={!r})".format(self.__class__.__name__, self.file_name, self.field_name)
565
566
class BaseParser:
    """Common base for every parser in this module; implements callback
    storage and dispatch.

    Two callback shapes exist. "Notification callbacks" signal that something
    happened - for example, that a new part of a multipart message was seen -
    and take no arguments. "Data callbacks" deliver a chunk of payload - for
    example, part of the body of a multipart chunk - and are invoked as::

        data_callback(data, start, end)

    where "data" is a bytestring and "start"/"end" are integer indexes into
    it: the slice `data[start:end]` is the region the callback is "interested
    in". The callback is not passed a copy of the data, since copying severely
    hurts performance.
    """

    def __init__(self) -> None:
        self.logger = logging.getLogger(__name__)

    def callback(self, name: str, data: bytes | None = None, start: int | None = None, end: int | None = None):
        """Invoke the callback registered under ``"on_" + name``, if any.

        Args:
            name: The name of the callback to call (as a string).
            data: Data to pass to the callback. If None, then it is assumed that the callback is a notification
                callback, and no parameters are given.
            end: An integer that is passed to the data callback.
            start: An integer that is passed to the data callback.
        """
        name = "on_" + name
        func = self.callbacks.get(name)
        if func is None:
            return

        if data is None:
            # Notification callback: invoked with no arguments.
            self.logger.debug("Calling %s with no data", name)
            func()
            return

        # Data callback: an empty region means there is nothing to report.
        if start is not None and start == end:
            return

        self.logger.debug("Calling %s with data[%d:%d]", name, start, end)
        func(data, start, end)

    def set_callback(self, name: str, new_func: Callable[..., Any] | None) -> None:
        """Install or remove the function for a callback. Removes it from the
        callbacks dict if new_func is None.

        :param name: The name of the callback to call (as a string).

        :param new_func: The new function for the callback. If None, then the
                         callback will be removed (with no error if it does not
                         exist).
        """
        key = "on_" + name
        if new_func is None:
            self.callbacks.pop(key, None)
        else:
            self.callbacks[key] = new_func

    def close(self):
        pass  # pragma: no cover

    def finalize(self):
        pass  # pragma: no cover

    def __repr__(self):
        return f"{self.__class__.__name__}()"
642
643
class OctetStreamParser(BaseParser):
    """This parser parses an octet-stream request body and calls callbacks when
    incoming data is received. Callbacks are as follows:

    | Callback Name  | Parameters      | Description                                          |
    |----------------|-----------------|------------------------------------------------------|
    | on_start       | None            | Called when the first data is parsed.                |
    | on_data        | data, start, end| Called for each data chunk that is parsed.           |
    | on_end         | None            | Called when the parser is finished parsing all data. |

    Args:
        callbacks: A dictionary of callbacks. See the documentation for [`BaseParser`][multipart.BaseParser].
            Defaults to a fresh, per-instance empty dictionary.
        max_size: The maximum size of body to parse. Defaults to infinity - i.e. unbounded.

    Raises:
        ValueError: If max_size is not a positive number.
    """

    def __init__(self, callbacks: OctetStreamCallbacks | None = None, max_size: float = float("inf")):
        super().__init__()
        # Use a fresh dict per instance. A mutable default argument (`= {}`)
        # would be shared by every parser constructed without an explicit
        # callbacks argument, so set_callback() on one instance would leak
        # callbacks into all the others.
        self.callbacks = callbacks if callbacks is not None else {}
        self._started = False

        if not isinstance(max_size, Number) or max_size < 1:
            raise ValueError("max_size must be a positive number, not %r" % max_size)
        self.max_size = max_size
        self._current_size = 0

    def write(self, data: bytes) -> int:
        """Write some data to the parser, which will perform size verification,
        and then pass the data to the underlying callback.

        Args:
            data: The data to write to the parser.

        Returns:
            The number of bytes written.
        """
        if not self._started:
            self.callback("start")
            self._started = True

        # Truncate data length.
        data_len = len(data)
        if (self._current_size + data_len) > self.max_size:
            # We truncate the length of data that we are to process.
            new_size = int(self.max_size - self._current_size)
            self.logger.warning(
                "Current size is %d (max %d), so truncating data length from %d to %d",
                self._current_size,
                self.max_size,
                data_len,
                new_size,
            )
            data_len = new_size

        # Increment size, then callback, in case there's an exception.
        self._current_size += data_len
        self.callback("data", data, 0, data_len)
        return data_len

    def finalize(self) -> None:
        """Finalize this parser, which signals to that we are finished parsing,
        and sends the on_end callback.
        """
        self.callback("end")

    def __repr__(self) -> str:
        return "%s()" % self.__class__.__name__
710
711
class QuerystringParser(BaseParser):
    """This is a streaming querystring parser. It will consume data, and call
    the callbacks given when it has data.

    | Callback Name  | Parameters      | Description |
    |----------------|-----------------|-----------------------------------------------------|
    | on_field_start | None            | Called when a new field is encountered. |
    | on_field_name  | data, start, end| Called when a portion of a field's name is encountered. |
    | on_field_data  | data, start, end| Called when a portion of a field's data is encountered. |
    | on_field_end   | None            | Called when the end of a field is encountered. |
    | on_end         | None            | Called when the parser is finished parsing all data.|

    Args:
        callbacks: A dictionary of callbacks. See the documentation for [`BaseParser`][multipart.BaseParser].
            Defaults to a fresh, per-instance empty dictionary.
        strict_parsing: Whether or not to parse the body strictly. Defaults to False. If this is set to True, then the
            behavior of the parser changes as the following: if a field has a value with an equal sign
            (e.g. "foo=bar", or "foo="), it is always included. If a field has no equals sign (e.g. "...&name&..."),
            it will be treated as an error if 'strict_parsing' is True, otherwise included. If an error is encountered,
            then a [`QuerystringParseError`][multipart.exceptions.QuerystringParseError] will be raised.
        max_size: The maximum size of body to parse. Defaults to infinity - i.e. unbounded.

    Raises:
        ValueError: If max_size is not a positive number.
    """  # noqa: E501

    state: QuerystringState

    def __init__(
        self,
        callbacks: QuerystringCallbacks | None = None,
        strict_parsing: bool = False,
        max_size: float = float("inf"),
    ) -> None:
        super().__init__()
        self.state = QuerystringState.BEFORE_FIELD
        self._found_sep = False

        # Use a fresh dict per instance. A mutable default argument (`= {}`)
        # would be shared by every parser constructed without an explicit
        # callbacks argument, so set_callback() on one instance would leak
        # callbacks into all the others.
        self.callbacks = callbacks if callbacks is not None else {}

        # Max-size stuff
        if not isinstance(max_size, Number) or max_size < 1:
            raise ValueError("max_size must be a positive number, not %r" % max_size)
        self.max_size = max_size
        self._current_size = 0

        # Should parsing be strict?
        self.strict_parsing = strict_parsing

    def write(self, data: bytes) -> int:
        """Write some data to the parser, which will perform size verification,
        parse into either a field name or value, and then pass the
        corresponding data to the underlying callback. If an error is
        encountered while parsing, a QuerystringParseError will be raised. The
        "offset" attribute of the raised exception will be set to the offset in
        the input data chunk (NOT the overall stream) that caused the error.

        Args:
            data: The data to write to the parser.

        Returns:
            The number of bytes written.
        """
        # Handle sizing.
        data_len = len(data)
        if (self._current_size + data_len) > self.max_size:
            # We truncate the length of data that we are to process.
            new_size = int(self.max_size - self._current_size)
            self.logger.warning(
                "Current size is %d (max %d), so truncating data length from %d to %d",
                self._current_size,
                self.max_size,
                data_len,
                new_size,
            )
            data_len = new_size

        l = 0
        try:
            l = self._internal_write(data, data_len)
        finally:
            self._current_size += l

        return l

    def _internal_write(self, data: bytes, length: int) -> int:
        """Run the parser state machine over data[:length].

        Only the first `length` bytes are examined: the caller may have
        truncated `length` below len(data) to enforce max_size, so every
        search below is also bounded by `length`.
        """
        state = self.state
        strict_parsing = self.strict_parsing
        found_sep = self._found_sep

        i = 0
        while i < length:
            ch = data[i]

            # Depending on our state...
            if state == QuerystringState.BEFORE_FIELD:
                # If the 'found_sep' flag is set, we've already encountered
                # and skipped a single separator. If so, we check our strict
                # parsing flag and decide what to do. Otherwise, we haven't
                # yet reached a separator, and thus, if we do, we need to skip
                # it as it will be the boundary between fields that's supposed
                # to be there.
                if ch == AMPERSAND or ch == SEMICOLON:
                    if found_sep:
                        # If we're parsing strictly, we disallow blank chunks.
                        if strict_parsing:
                            e = QuerystringParseError("Skipping duplicate ampersand/semicolon at %d" % i)
                            e.offset = i
                            raise e
                        else:
                            self.logger.debug("Skipping duplicate ampersand/semicolon at %d", i)
                    else:
                        # This case is when we're skipping the (first)
                        # separator between fields, so we just set our flag
                        # and continue on.
                        found_sep = True
                else:
                    # Emit a field-start event, and go to that state. Also,
                    # reset the "found_sep" flag, for the next time we get to
                    # this state.
                    self.callback("field_start")
                    i -= 1
                    state = QuerystringState.FIELD_NAME
                    found_sep = False

            elif state == QuerystringState.FIELD_NAME:
                # Try and find a separator - we ensure that, if we do, we only
                # look for the equal sign before it. The searches are bounded
                # by `length` so that bytes truncated by the max_size check
                # are never parsed or emitted.
                sep_pos = data.find(b"&", i, length)
                if sep_pos == -1:
                    sep_pos = data.find(b";", i, length)

                # See if we can find an equals sign in the remaining data. If
                # so, we can immediately emit the field name and jump to the
                # data state.
                if sep_pos != -1:
                    equals_pos = data.find(b"=", i, sep_pos)
                else:
                    equals_pos = data.find(b"=", i, length)

                if equals_pos != -1:
                    # Emit this name.
                    self.callback("field_name", data, i, equals_pos)

                    # Jump i to this position. Note that it will then have 1
                    # added to it below, which means the next iteration of this
                    # loop will inspect the character after the equals sign.
                    i = equals_pos
                    state = QuerystringState.FIELD_DATA
                else:
                    # No equals sign found.
                    if not strict_parsing:
                        # See also comments in the QuerystringState.FIELD_DATA case below.
                        # If we found the separator, we emit the name and just
                        # end - there's no data callback at all (not even with
                        # a blank value).
                        if sep_pos != -1:
                            self.callback("field_name", data, i, sep_pos)
                            self.callback("field_end")

                            i = sep_pos - 1
                            state = QuerystringState.BEFORE_FIELD
                        else:
                            # Otherwise, no separator in this block, so the
                            # rest of this chunk must be a name.
                            self.callback("field_name", data, i, length)
                            i = length

                    else:
                        # We're parsing strictly. If we find a separator,
                        # this is an error - we require an equals sign.
                        if sep_pos != -1:
                            e = QuerystringParseError(
                                "When strict_parsing is True, we require an "
                                "equals sign in all field chunks. Did not "
                                "find one in the chunk that starts at %d" % (i,)
                            )
                            e.offset = i
                            raise e

                        # No separator in the rest of this chunk, so it's just
                        # a field name.
                        self.callback("field_name", data, i, length)
                        i = length

            elif state == QuerystringState.FIELD_DATA:
                # Try finding either an ampersand or a semicolon after this
                # position (bounded by `length`; see the note above).
                sep_pos = data.find(b"&", i, length)
                if sep_pos == -1:
                    sep_pos = data.find(b";", i, length)

                # If we found it, callback this bit as data and then go back
                # to expecting to find a field.
                if sep_pos != -1:
                    self.callback("field_data", data, i, sep_pos)
                    self.callback("field_end")

                    # Note that we go to the separator, which brings us to the
                    # "before field" state. This allows us to properly emit
                    # "field_start" events only when we actually have data for
                    # a field of some sort.
                    i = sep_pos - 1
                    state = QuerystringState.BEFORE_FIELD

                # Otherwise, emit the rest as data and finish.
                else:
                    self.callback("field_data", data, i, length)
                    i = length

            else:  # pragma: no cover (error case)
                msg = "Reached an unknown state %d at %d" % (state, i)
                self.logger.warning(msg)
                e = QuerystringParseError(msg)
                e.offset = i
                raise e

            i += 1

        self.state = state
        self._found_sep = found_sep
        # NOTE: historically this returns len(data), not `length`, even when
        # the caller truncated the chunk; write() adds this to _current_size.
        return len(data)

    def finalize(self) -> None:
        """Finalize this parser, which signals to that we are finished parsing,
        if we're still in the middle of a field, an on_field_end callback, and
        then the on_end callback.
        """
        # If we're currently in the middle of a field, we finish it.
        if self.state == QuerystringState.FIELD_DATA:
            self.callback("field_end")
        self.callback("end")

    def __repr__(self) -> str:
        return "{}(strict_parsing={!r}, max_size={!r})".format(
            self.__class__.__name__, self.strict_parsing, self.max_size
        )
942
943
class MultipartParser(BaseParser):
    """This class is a streaming multipart/form-data parser.

    | Callback Name      | Parameters      | Description |
    |--------------------|-----------------|-------------|
    | on_part_begin      | None            | Called when a new part of the multipart message is encountered. |
    | on_part_data       | data, start, end| Called when a portion of a part's data is encountered. |
    | on_part_end        | None            | Called when the end of a part is reached. |
    | on_header_begin    | None            | Called when we've found a new header in a part of a multipart message |
    | on_header_field    | data, start, end| Called each time an additional portion of a header is read (i.e. the part of the header that is before the colon; the "Foo" in "Foo: Bar"). |
    | on_header_value    | data, start, end| Called when we get data for a header. |
    | on_header_end      | None            | Called when the current header is finished - i.e. we've reached the newline at the end of the header. |
    | on_headers_finished| None            | Called when all headers are finished, and before the part data starts. |
    | on_end             | None            | Called when the parser is finished parsing all data. |

    Args:
        boundary: The multipart boundary. This is required, and must match what is given in the HTTP request - usually in the Content-Type header.
        callbacks: A dictionary of callbacks. See the documentation for [`BaseParser`][multipart.BaseParser].
        max_size: The maximum size of body to parse. Defaults to infinity - i.e. unbounded.
    """  # noqa: E501

    def __init__(
        self, boundary: bytes | str, callbacks: MultipartCallbacks = {}, max_size: float = float("inf")
    ) -> None:
        # Initialize parser state.
        super().__init__()
        self.state = MultipartState.START
        # 'index' tracks our position within the boundary while matching;
        # 'flags' records whether we saw a part boundary or the final one.
        self.index = self.flags = 0

        # NOTE: the shared {} default is safe here because callbacks is only
        # ever read, never mutated.
        self.callbacks = callbacks

        if not isinstance(max_size, Number) or max_size < 1:
            raise ValueError("max_size must be a positive number, not %r" % max_size)
        self.max_size = max_size
        self._current_size = 0

        # Setup marks. These are used to track the state of data received.
        self.marks: dict[str, int] = {}

        # TODO: Actually use this rather than the dumb version we currently use
        # # Precompute the skip table for the Boyer-Moore-Horspool algorithm.
        # skip = [len(boundary) for x in range(256)]
        # for i in range(len(boundary) - 1):
        #     skip[ord_char(boundary[i])] = len(boundary) - i - 1
        #
        # # We use a tuple since it's a constant, and marginally faster.
        # self.skip = tuple(skip)

        # Save our boundary, always prefixed with the CRLF + double-hyphen
        # delimiter that precedes every boundary occurrence in the body.
        if isinstance(boundary, str):  # pragma: no cover
            boundary = boundary.encode("latin-1")
        self.boundary = b"\r\n--" + boundary

        # Get a set of characters that belong to our boundary.
        self.boundary_chars = frozenset(self.boundary)

        # We also create a lookbehind list.
        # Note: the +8 is since we can have, at maximum, "\r\n--" + boundary +
        # "--\r\n" at the final boundary, and the length of '\r\n--' and
        # '--\r\n' is 8 bytes.
        self.lookbehind = [NULL for _ in range(len(boundary) + 8)]

    def write(self, data: bytes) -> int:
        """Write some data to the parser, which will perform size verification,
        and then parse the data into the appropriate location (e.g. header,
        data, etc.), and pass this on to the underlying callback.  If an error
        is encountered, a MultipartParseError will be raised.  The "offset"
        attribute on the raised exception will be set to the offset of the byte
        in the input chunk that caused the error.

        Args:
            data: The data to write to the parser.

        Returns:
            The number of bytes written.  This may be less than ``len(data)``
            when the configured ``max_size`` is reached and the input is
            truncated.
        """
        # Handle sizing.
        data_len = len(data)
        if (self._current_size + data_len) > self.max_size:
            # We truncate the length of data that we are to process.
            new_size = int(self.max_size - self._current_size)
            self.logger.warning(
                "Current size is %d (max %d), so truncating data length from %d to %d",
                self._current_size,
                self.max_size,
                data_len,
                new_size,
            )
            data_len = new_size

        l = 0
        try:
            l = self._internal_write(data, data_len)
        finally:
            # Track the total size even if parsing raised, so that repeated
            # writes cannot exceed max_size.
            self._current_size += l

        return l

    def _internal_write(self, data: bytes, length: int) -> int:
        # Get values from locals.
        boundary = self.boundary

        # Get our state, flags and index.  These are persisted between calls to
        # this function.
        state = self.state
        index = self.index
        flags = self.flags

        # Our index defaults to 0.
        i = 0

        # Set a mark.  A mark remembers where in the current chunk a run of
        # data (header field, header value, or part data) began.
        def set_mark(name: str) -> None:
            self.marks[name] = i

        # Remove a mark.
        # NOTE(review): the 'reset' parameter is accepted but currently
        # ignored; callers only rely on the mark being dropped.
        def delete_mark(name: str, reset: bool = False) -> None:
            self.marks.pop(name, None)

        # Helper function that makes calling a callback with data easier. The
        # 'remaining' parameter will callback from the marked value until the
        # end of the buffer, and reset the mark, instead of deleting it.  This
        # is used at the end of the function to call our callbacks with any
        # remaining data in this chunk.
        def data_callback(name: str, remaining: bool = False) -> None:
            marked_index = self.marks.get(name)
            if marked_index is None:
                return

            # If we're getting remaining data, we ignore the current i value
            # and just call with the remaining data.
            if remaining:
                self.callback(name, data, marked_index, length)
                self.marks[name] = 0

            # Otherwise, we call it from the mark to the current byte we're
            # processing.
            else:
                self.callback(name, data, marked_index, i)
                self.marks.pop(name, None)

        # For each byte...
        while i < length:
            c = data[i]

            if state == MultipartState.START:
                # Skip leading newlines
                if c == CR or c == LF:
                    i += 1
                    self.logger.debug("Skipping leading CR/LF at %d", i)
                    continue

                # index is used as in index into our boundary.  Set to 0.
                index = 0

                # Move to the next state, but decrement i so that we re-process
                # this character.
                state = MultipartState.START_BOUNDARY
                i -= 1

            elif state == MultipartState.START_BOUNDARY:
                # Check to ensure that the last 2 characters in our boundary
                # are CRLF.
                if index == len(boundary) - 2:
                    if c != CR:
                        # Error!
                        msg = "Did not find CR at end of boundary (%d)" % (i,)
                        self.logger.warning(msg)
                        e = MultipartParseError(msg)
                        e.offset = i
                        raise e

                    index += 1

                elif index == len(boundary) - 2 + 1:
                    if c != LF:
                        msg = "Did not find LF at end of boundary (%d)" % (i,)
                        self.logger.warning(msg)
                        e = MultipartParseError(msg)
                        e.offset = i
                        raise e

                    # The index is now used for indexing into our boundary.
                    index = 0

                    # Callback for the start of a part.
                    self.callback("part_begin")

                    # Move to the next character and state.
                    state = MultipartState.HEADER_FIELD_START

                else:
                    # Check to ensure our boundary matches.  The +2 skips the
                    # leading "\r\n" of self.boundary, which is absent before
                    # the very first boundary in the body.
                    if c != boundary[index + 2]:
                        msg = "Did not find boundary character %r at index " "%d" % (c, index + 2)
                        self.logger.warning(msg)
                        e = MultipartParseError(msg)
                        e.offset = i
                        raise e

                    # Increment index into boundary and continue.
                    index += 1

            elif state == MultipartState.HEADER_FIELD_START:
                # Mark the start of a header field here, reset the index, and
                # continue parsing our header field.
                index = 0

                # Set a mark of our header field.
                set_mark("header_field")

                # Notify that we're starting a header if the next character is
                # not a CR; a CR at the beginning of the header will cause us
                # to stop parsing headers in the MultipartState.HEADER_FIELD state,
                # below.
                if c != CR:
                    self.callback("header_begin")

                # Move to parsing header fields.
                state = MultipartState.HEADER_FIELD
                i -= 1

            elif state == MultipartState.HEADER_FIELD:
                # If we've reached a CR at the beginning of a header, it means
                # that we've reached the second of 2 newlines, and so there are
                # no more headers to parse.
                if c == CR and index == 0:
                    delete_mark("header_field")
                    state = MultipartState.HEADERS_ALMOST_DONE
                    i += 1
                    continue

                # Increment our index in the header.
                index += 1

                # If we've reached a colon, we're done with this header.
                if c == COLON:
                    # A 0-length header is an error.
                    if index == 1:
                        msg = "Found 0-length header at %d" % (i,)
                        self.logger.warning(msg)
                        e = MultipartParseError(msg)
                        e.offset = i
                        raise e

                    # Call our callback with the header field.
                    data_callback("header_field")

                    # Move to parsing the header value.
                    state = MultipartState.HEADER_VALUE_START

                elif c not in TOKEN_CHARS_SET:
                    msg = "Found invalid character %r in header at %d" % (c, i)
                    self.logger.warning(msg)
                    e = MultipartParseError(msg)
                    e.offset = i
                    raise e

            elif state == MultipartState.HEADER_VALUE_START:
                # Skip leading spaces.
                if c == SPACE:
                    i += 1
                    continue

                # Mark the start of the header value.
                set_mark("header_value")

                # Move to the header-value state, reprocessing this character.
                state = MultipartState.HEADER_VALUE
                i -= 1

            elif state == MultipartState.HEADER_VALUE:
                # If we've got a CR, we're nearly done our headers.  Otherwise,
                # we do nothing and just move past this character.
                if c == CR:
                    data_callback("header_value")
                    self.callback("header_end")
                    state = MultipartState.HEADER_VALUE_ALMOST_DONE

            elif state == MultipartState.HEADER_VALUE_ALMOST_DONE:
                # The last character should be a LF.  If not, it's an error.
                if c != LF:
                    msg = "Did not find LF character at end of header " "(found %r)" % (c,)
                    self.logger.warning(msg)
                    e = MultipartParseError(msg)
                    e.offset = i
                    raise e

                # Move back to the start of another header.  Note that if that
                # state detects ANOTHER newline, it'll trigger the end of our
                # headers.
                state = MultipartState.HEADER_FIELD_START

            elif state == MultipartState.HEADERS_ALMOST_DONE:
                # We're almost done our headers.  This is reached when we parse
                # a CR at the beginning of a header, so our next character
                # should be a LF, or it's an error.
                if c != LF:
                    msg = f"Did not find LF at end of headers (found {c!r})"
                    self.logger.warning(msg)
                    e = MultipartParseError(msg)
                    e.offset = i
                    raise e

                self.callback("headers_finished")
                state = MultipartState.PART_DATA_START

            elif state == MultipartState.PART_DATA_START:
                # Mark the start of our part data.
                set_mark("part_data")

                # Start processing part data, including this character.
                state = MultipartState.PART_DATA
                i -= 1

            elif state == MultipartState.PART_DATA:
                # We're processing our part data right now.  During this, we
                # need to efficiently search for our boundary, since any data
                # on any number of lines can be a part of the current data.
                # We use the Boyer-Moore-Horspool algorithm to efficiently
                # search through the remainder of the buffer looking for our
                # boundary.

                # Save the current value of our index.  We use this in case we
                # find part of a boundary, but it doesn't match fully.
                prev_index = index

                # Set up variables.
                boundary_length = len(boundary)
                boundary_end = boundary_length - 1
                data_length = length
                boundary_chars = self.boundary_chars

                # If our index is 0, we're starting a new part, so start our
                # search.
                if index == 0:
                    # Search forward until we either hit the end of our buffer,
                    # or reach a character that's in our boundary.  Bytes that
                    # cannot occur anywhere in the boundary let us skip a whole
                    # boundary length at a time.
                    i += boundary_end
                    while i < data_length - 1 and data[i] not in boundary_chars:
                        i += boundary_length

                    # Reset i back the length of our boundary, which is the
                    # earliest possible location that could be our match (i.e.
                    # if we've just broken out of our loop since we saw the
                    # last character in our boundary)
                    i -= boundary_end
                    c = data[i]

                # Now, we have a couple of cases here.  If our index is before
                # the end of the boundary...
                if index < boundary_length:
                    # If the character matches...
                    if boundary[index] == c:
                        # If we found a match for our boundary, we send the
                        # existing data.
                        if index == 0:
                            data_callback("part_data")

                        # The current character matches, so continue!
                        index += 1
                    else:
                        index = 0

                # Our index is equal to the length of our boundary!
                elif index == boundary_length:
                    # First we increment it.
                    index += 1

                    # Now, if we've reached a newline, we need to set this as
                    # the potential end of our boundary.
                    if c == CR:
                        flags |= FLAG_PART_BOUNDARY

                    # Otherwise, if this is a hyphen, we might be at the last
                    # of all boundaries.
                    elif c == HYPHEN:
                        flags |= FLAG_LAST_BOUNDARY

                    # Otherwise, we reset our index, since this isn't either a
                    # newline or a hyphen.
                    else:
                        index = 0

                # Our index is right after the part boundary, which should be
                # a LF.
                elif index == boundary_length + 1:
                    # If we're at a part boundary (i.e. we've seen a CR
                    # character already)...
                    if flags & FLAG_PART_BOUNDARY:
                        # We need a LF character next.
                        if c == LF:
                            # Unset the part boundary flag.
                            flags &= ~FLAG_PART_BOUNDARY

                            # Callback indicating that we've reached the end of
                            # a part, and are starting a new one.
                            self.callback("part_end")
                            self.callback("part_begin")

                            # Move to parsing new headers.
                            index = 0
                            state = MultipartState.HEADER_FIELD_START
                            i += 1
                            continue

                        # We didn't find an LF character, so no match.  Reset
                        # our index and clear our flag.
                        index = 0
                        flags &= ~FLAG_PART_BOUNDARY

                    # Otherwise, if we're at the last boundary (i.e. we've
                    # seen a hyphen already)...
                    elif flags & FLAG_LAST_BOUNDARY:
                        # We need a second hyphen here.
                        if c == HYPHEN:
                            # Callback to end the current part, and then the
                            # message.
                            self.callback("part_end")
                            self.callback("end")
                            state = MultipartState.END
                        else:
                            # No match, so reset index.
                            index = 0

                # If we have an index, we need to keep this byte for later, in
                # case we can't match the full boundary.
                if index > 0:
                    self.lookbehind[index - 1] = c

                # Otherwise, our index is 0.  If the previous index is not, it
                # means we reset something, and we need to take the data we
                # thought was part of our boundary and send it along as actual
                # data.
                elif prev_index > 0:
                    # Callback to write the saved data.
                    lb_data = join_bytes(self.lookbehind)
                    self.callback("part_data", lb_data, 0, prev_index)

                    # Overwrite our previous index.
                    prev_index = 0

                    # Re-set our mark for part data.
                    set_mark("part_data")

                    # Re-consider the current character, since this could be
                    # the start of the boundary itself.
                    i -= 1

            elif state == MultipartState.END:
                # Do nothing and just consume a byte in the end state.
                if c not in (CR, LF):
                    self.logger.warning("Consuming a byte '0x%x' in the end state", c)

            else:  # pragma: no cover (error case)
                # We got into a strange state somehow!  Just stop processing.
                msg = "Reached an unknown state %d at %d" % (state, i)
                self.logger.warning(msg)
                e = MultipartParseError(msg)
                e.offset = i
                raise e

            # Move to the next byte.
            i += 1

        # We call our callbacks with any remaining data.  Note that we pass
        # the 'remaining' flag, which sets the mark back to 0 instead of
        # deleting it, if it's found.  This is because, if the mark is found
        # at this point, we assume that there's data for one of these things
        # that has been parsed, but not yet emitted.  And, as such, it implies
        # that we haven't yet reached the end of this 'thing'.  So, by setting
        # the mark to 0, we cause any data callbacks that take place in future
        # calls to this function to start from the beginning of that buffer.
        data_callback("header_field", True)
        data_callback("header_value", True)
        data_callback("part_data", True)

        # Save values to locals.
        self.state = state
        self.index = index
        self.flags = flags

        # Return our data length to indicate no errors, and that we processed
        # all of it.
        return length

    def finalize(self) -> None:
        """Finalize this parser, which signals to that we are finished parsing.

        Note: It does not currently, but in the future, it will verify that we
        are in the final state of the parser (i.e. the end of the multipart
        message is well-formed), and, if not, throw an error.
        """
        # TODO: verify that we're in the state MultipartState.END, otherwise throw an
        # error or otherwise state that we're not finished parsing.
        pass

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(boundary={self.boundary!r})"
1443
1444
class FormParser:
    """This class is the all-in-one form parser.  Given all the information
    necessary to parse a form, it will instantiate the correct parser, create
    the proper :class:`Field` and :class:`File` classes to store the data that
    is parsed, and call the two given callbacks with each field and file as
    they become available.

    Args:
        content_type: The Content-Type of the incoming request.  This is used to select the appropriate parser.
        on_field: The callback to call when a field has been parsed and is ready for usage.  See above for parameters.
        on_file: The callback to call when a file has been parsed and is ready for usage.  See above for parameters.
        on_end: An optional callback to call when all fields and files in a request has been parsed.  Can be None.
        boundary: If the request is a multipart/form-data request, this should be the boundary of the request, as given
            in the Content-Type header, as a bytestring.
        file_name: If the request is of type application/octet-stream, then the body of the request will not contain any
            information about the uploaded file.  In such cases, you can provide the file name of the uploaded file
            manually.
        FileClass: The class to use for uploaded files.  Defaults to :class:`File`, but you can provide your own class
            if you wish to customize behaviour.  The class will be instantiated as FileClass(file_name, field_name), and
            it must provide the following functions::
                - file_instance.write(data)
                - file_instance.finalize()
                - file_instance.close()
        FieldClass: The class to use for uploaded fields.  Defaults to :class:`Field`, but you can provide your own
            class if you wish to customize behaviour.  The class will be instantiated as FieldClass(field_name), and it
            must provide the following functions::
                - field_instance.write(data)
                - field_instance.finalize()
                - field_instance.close()
                - field_instance.set_none()
        config: Configuration to use for this FormParser.  The default values are taken from the DEFAULT_CONFIG value,
            and then any keys present in this dictionary will overwrite the default values.
    """

    #: This is the default configuration for our form parser.
    #: Note: all file sizes should be in bytes.
    DEFAULT_CONFIG: FormParserConfig = {
        "MAX_BODY_SIZE": float("inf"),
        "MAX_MEMORY_FILE_SIZE": 1 * 1024 * 1024,
        "UPLOAD_DIR": None,
        "UPLOAD_KEEP_FILENAME": False,
        "UPLOAD_KEEP_EXTENSIONS": False,
        # Error on invalid Content-Transfer-Encoding?
        "UPLOAD_ERROR_ON_BAD_CTE": False,
    }

    def __init__(
        self,
        content_type: str,
        on_field: OnFieldCallback,
        on_file: OnFileCallback,
        on_end: Callable[[], None] | None = None,
        boundary: bytes | str | None = None,
        file_name: bytes | None = None,
        FileClass: type[FileProtocol] = File,
        FieldClass: type[FieldProtocol] = Field,
        config: dict[Any, Any] = {},
    ) -> None:
        self.logger = logging.getLogger(__name__)

        # Save variables.
        self.content_type = content_type
        self.boundary = boundary
        self.bytes_received = 0
        self.parser = None

        # Save callbacks.
        self.on_field = on_field
        self.on_file = on_file
        self.on_end = on_end

        # Save classes.  BUGFIX: previously these stored the module-level
        # defaults (File / Field) and silently ignored any caller-provided
        # FileClass / FieldClass, contradicting the documented customization
        # contract above.
        self.FileClass = FileClass
        self.FieldClass = FieldClass

        # Set configuration options.  The shared {} default for 'config' is
        # safe because it is only read from, never mutated.
        self.config = self.DEFAULT_CONFIG.copy()
        self.config.update(config)

        # Depending on the Content-Type, we instantiate the correct parser.
        if content_type == "application/octet-stream":
            # The entire body is a single anonymous file; buffer it into one
            # FileClass instance and hand it to on_file at the end.
            file: FileProtocol = None  # type: ignore

            def on_start() -> None:
                nonlocal file
                file = FileClass(file_name, None, config=self.config)

            def on_data(data: bytes, start: int, end: int) -> None:
                nonlocal file
                file.write(data[start:end])

            def _on_end() -> None:
                nonlocal file
                # Finalize the file itself.
                file.finalize()

                # Call our callback.
                on_file(file)

                # Call the on-end callback.
                if self.on_end is not None:
                    self.on_end()

            # Instantiate an octet-stream parser
            parser = OctetStreamParser(
                callbacks={"on_start": on_start, "on_data": on_data, "on_end": _on_end},
                max_size=self.config["MAX_BODY_SIZE"],
            )

        elif content_type == "application/x-www-form-urlencoded" or content_type == "application/x-url-encoded":
            # Accumulate name chunks until the first data chunk arrives, at
            # which point the FieldClass instance is created lazily.
            name_buffer: list[bytes] = []

            f: FieldProtocol = None  # type: ignore

            def on_field_start() -> None:
                pass

            def on_field_name(data: bytes, start: int, end: int) -> None:
                name_buffer.append(data[start:end])

            def on_field_data(data: bytes, start: int, end: int) -> None:
                nonlocal f
                if f is None:
                    f = FieldClass(b"".join(name_buffer))
                    del name_buffer[:]
                f.write(data[start:end])

            def on_field_end() -> None:
                nonlocal f
                # Finalize and call callback.
                if f is None:
                    # If we get here, it's because there was no field data.
                    # We create a field, set it to None, and then continue.
                    f = FieldClass(b"".join(name_buffer))
                    del name_buffer[:]
                    f.set_none()

                f.finalize()
                on_field(f)
                f = None

            def _on_end() -> None:
                if self.on_end is not None:
                    self.on_end()

            # Instantiate parser.
            parser = QuerystringParser(
                callbacks={
                    "on_field_start": on_field_start,
                    "on_field_name": on_field_name,
                    "on_field_data": on_field_data,
                    "on_field_end": on_field_end,
                    "on_end": _on_end,
                },
                max_size=self.config["MAX_BODY_SIZE"],
            )

        elif content_type == "multipart/form-data":
            if boundary is None:
                self.logger.error("No boundary given")
                raise FormParserError("No boundary given")

            # Per-part state: accumulated header name/value chunks, the
            # parsed header dict, the current Field/File object, and the
            # writer (the object itself, or a CTE decoder wrapping it).
            header_name: list[bytes] = []
            header_value: list[bytes] = []
            headers = {}

            f: FileProtocol | FieldProtocol | None = None
            writer = None
            is_file = False

            def on_part_begin() -> None:
                # Reset headers in case this isn't the first part.
                nonlocal headers
                headers = {}

            def on_part_data(data: bytes, start: int, end: int) -> None:
                nonlocal writer
                bytes_processed = writer.write(data[start:end])
                # TODO: check for error here.
                return bytes_processed

            def on_part_end() -> None:
                nonlocal f, is_file
                f.finalize()
                if is_file:
                    on_file(f)
                else:
                    on_field(f)

            def on_header_field(data: bytes, start: int, end: int) -> None:
                header_name.append(data[start:end])

            def on_header_value(data: bytes, start: int, end: int) -> None:
                header_value.append(data[start:end])

            def on_header_end() -> None:
                headers[b"".join(header_name)] = b"".join(header_value)
                del header_name[:]
                del header_value[:]

            def on_headers_finished() -> None:
                nonlocal is_file, f, writer
                # Reset the 'is file' flag.
                is_file = False

                # Parse the content-disposition header.
                # TODO: handle mixed case
                content_disp = headers.get(b"Content-Disposition")
                disp, options = parse_options_header(content_disp)

                # Get the field and filename.
                field_name = options.get(b"name")
                file_name = options.get(b"filename")
                # TODO: check for errors

                # Create the proper class.  A part with a filename is a file
                # upload; one without is a plain form field.
                if file_name is None:
                    f = FieldClass(field_name)
                else:
                    f = FileClass(file_name, field_name, config=self.config)
                    is_file = True

                # Parse the given Content-Transfer-Encoding to determine what
                # we need to do with the incoming data.
                # TODO: check that we properly handle 8bit / 7bit encoding.
                transfer_encoding = headers.get(b"Content-Transfer-Encoding", b"7bit")

                if transfer_encoding in (b"binary", b"8bit", b"7bit"):
                    writer = f

                elif transfer_encoding == b"base64":
                    writer = Base64Decoder(f)

                elif transfer_encoding == b"quoted-printable":
                    writer = QuotedPrintableDecoder(f)

                else:
                    self.logger.warning("Unknown Content-Transfer-Encoding: %r", transfer_encoding)
                    if self.config["UPLOAD_ERROR_ON_BAD_CTE"]:
                        raise FormParserError('Unknown Content-Transfer-Encoding "{}"'.format(transfer_encoding))
                    else:
                        # If we aren't erroring, then we just treat this as an
                        # unencoded Content-Transfer-Encoding.
                        writer = f

            def _on_end() -> None:
                nonlocal writer
                writer.finalize()
                if self.on_end is not None:
                    self.on_end()

            # Instantiate a multipart parser.
            parser = MultipartParser(
                boundary,
                callbacks={
                    "on_part_begin": on_part_begin,
                    "on_part_data": on_part_data,
                    "on_part_end": on_part_end,
                    "on_header_field": on_header_field,
                    "on_header_value": on_header_value,
                    "on_header_end": on_header_end,
                    "on_headers_finished": on_headers_finished,
                    "on_end": _on_end,
                },
                max_size=self.config["MAX_BODY_SIZE"],
            )

        else:
            self.logger.warning("Unknown Content-Type: %r", content_type)
            raise FormParserError("Unknown Content-Type: {}".format(content_type))

        self.parser = parser

    def write(self, data: bytes) -> int:
        """Write some data.  The parser will forward this to the appropriate
        underlying parser.

        Args:
            data: The data to write.

        Returns:
            The number of bytes processed.
        """
        self.bytes_received += len(data)
        # TODO: check the parser's return value for errors?
        return self.parser.write(data)

    def finalize(self) -> None:
        """Finalize the parser."""
        if self.parser is not None and hasattr(self.parser, "finalize"):
            self.parser.finalize()

    def close(self) -> None:
        """Close the parser."""
        if self.parser is not None and hasattr(self.parser, "close"):
            self.parser.close()

    def __repr__(self) -> str:
        return "{}(content_type={!r}, parser={!r})".format(self.__class__.__name__, self.content_type, self.parser)
1744
1745
def create_form_parser(
    headers: dict[str, bytes],
    on_field: OnFieldCallback,
    on_file: OnFileCallback,
    trust_x_headers: bool = False,
    config: dict[Any, Any] = {},
) -> FormParser:
    """This function is a helper function to aid in creating a FormParser
    instances.  Given a dictionary-like headers object, it will determine
    the correct information needed, instantiate a FormParser with the
    appropriate values and given callbacks, and then return the corresponding
    parser.

    Args:
        headers: A dictionary-like object of HTTP headers.  The only required header is Content-Type.
        on_field: Callback to call with each parsed field.
        on_file: Callback to call with each parsed file.
        trust_x_headers: Whether or not to trust information received from certain X-Headers - for example, the file
            name from X-File-Name.  Defaults to False, in which case X-File-Name is ignored.
        config: Configuration variables to pass to the FormParser.

    Raises:
        ValueError: If no Content-Type header is present.
    """
    content_type = headers.get("Content-Type")
    if content_type is None:
        logging.getLogger(__name__).warning("No Content-Type header given")
        raise ValueError("No Content-Type header given!")

    # Boundaries are optional (the FormParser will raise if one is needed
    # but not given).
    content_type, params = parse_options_header(content_type)
    boundary = params.get(b"boundary")

    # We need content_type to be a string, not a bytes object.
    content_type = content_type.decode("latin-1")

    # File names are optional.  BUGFIX: X-File-Name is client-supplied and
    # therefore untrusted; previously it was honored unconditionally even
    # though the trust_x_headers parameter existed.  Only read it when the
    # caller has explicitly opted in.
    file_name = headers.get("X-File-Name") if trust_x_headers else None

    # Instantiate a form parser.
    form_parser = FormParser(content_type, on_field, on_file, boundary=boundary, file_name=file_name, config=config)

    # Return our parser.
    return form_parser
1788
1789
def parse_form(
    headers: dict[str, bytes],
    input_stream: io.FileIO,
    on_field: OnFieldCallback,
    on_file: OnFileCallback,
    chunk_size: int = 1048576,
) -> None:
    """Parse a request body with minimal setup.

    Pass a dictionary-like object of the request's headers and a file-like
    object for the input stream, along with two callbacks that are invoked
    whenever a field or file has been parsed.

    Args:
        headers: A dictionary-like object of HTTP headers.  The only required header is Content-Type.
        input_stream: A file-like object that represents the request body.  The read() method must return bytestrings.
        on_field: Callback to call with each parsed field.
        on_file: Callback to call with each parsed file.
        chunk_size: The maximum size to read from the input stream and write to the parser at one time.
            Defaults to 1 MiB.
    """
    # Build the parser appropriate for the request's Content-Type.
    parser = create_form_parser(headers, on_field, on_file)

    # Determine how many bytes we are allowed to consume.  With no
    # Content-Length header we read until the stream is exhausted.
    raw_length = headers.get("Content-Length")
    content_length = int(raw_length) if raw_length is not None else float("inf")

    bytes_read = 0
    while True:
        # Never request more than the remaining Content-Length allows.
        to_read = min(content_length - bytes_read, chunk_size)
        chunk = input_stream.read(to_read)

        # Feed the chunk to the parser and account for it.
        parser.write(chunk)
        bytes_read += len(chunk)

        # A short read means the stream is exhausted; hitting the declared
        # Content-Length also ends the loop.
        if len(chunk) != to_read or bytes_read == content_length:
            break

    # Signal to the parser that no more data is coming.
    parser.finalize()