1from __future__ import annotations 
    2 
    3import re 
    4import typing as t 
    5from dataclasses import dataclass 
    6from enum import auto 
    7from enum import Enum 
    8 
    9from ..datastructures import Headers 
    10from ..exceptions import RequestEntityTooLarge 
    11from ..http import parse_options_header 
    12 
    13 
    14class Event: 
    15    pass 
    16 
    17 
@dataclass(frozen=True)
class Preamble(Event):
    """Event for the bytes that precede the first boundary of a message."""

    # Raw bytes found in the buffer before the first boundary match.
    data: bytes
    21 
    22 
@dataclass(frozen=True)
class Field(Event):
    """Event marking the start of a part that has no ``filename`` option.

    The part's body follows as one or more ``Data`` events.
    """

    # Value of the ``name`` option of the Content-Disposition header.
    name: str
    # All headers of the part.
    headers: Headers
    27 
    28 
@dataclass(frozen=True)
class File(Event):
    """Event marking the start of a part that has a ``filename`` option.

    The part's body follows as one or more ``Data`` events.
    """

    # Value of the ``name`` option of the Content-Disposition header.
    name: str
    # Value of the ``filename`` option of the Content-Disposition header.
    filename: str
    # All headers of the part.
    headers: Headers
    34 
    35 
@dataclass(frozen=True)
class Data(Event):
    """Event carrying a chunk of the body of the current part."""

    # The chunk of body bytes; may be empty.
    data: bytes
    # True while the boundary that ends the part has not been seen yet,
    # i.e. further ``Data`` events for the same part will follow.
    more_data: bool
    40 
    41 
@dataclass(frozen=True)
class Epilogue(Event):
    """Event for the bytes that follow the closing boundary of a message."""

    # Everything left in the buffer after the closing boundary.
    data: bytes
    45 
    46 
    47class NeedData(Event): 
    48    pass 
    49 
    50 
# Shared instance used as the default result of the decoder's
# ``next_event`` when more input is required.
NEED_DATA = NeedData()
    52 
    53 
class State(Enum):
    """Positions within a multipart message used by the decoder and encoder."""

    # Before the first boundary.
    PREAMBLE = auto()
    # Directly after a boundary: the headers of a part come next.
    PART = auto()
    # Inside the body of a part, after its first chunk.
    DATA = auto()
    # Directly after the headers of a part: the first body chunk, which
    # begins with the line break that separates headers from body.
    DATA_START = auto()
    # After the closing boundary.
    EPILOGUE = auto()
    # The epilogue has been handled; the message is finished.
    COMPLETE = auto()
    61 
    62 
# Multipart line breaks MUST be CRLF (\r\n) by RFC-7578, except that
# many implementations break this and either use CR or LF alone.
# LINE_BREAK is a pattern fragment (not compiled) so that it can be
# interpolated into the larger patterns built from it.
LINE_BREAK = b"(?:\r\n|\n|\r)"
# Two identical line breaks in a row, i.e. an empty line. The decoder
# uses it to find the end of the headers of a part.
BLANK_LINE_RE = re.compile(b"(?:\r\n\r\n|\r\r|\n\n)", re.MULTILINE)
LINE_BREAK_RE = re.compile(LINE_BREAK, re.MULTILINE)
# Header values can be continued via a space or tab after the linebreak, as
# per RFC2231
# NOTE(review): header folding itself is defined by RFC 5322 section
# 2.2.3; RFC 2231 covers parameter value continuations - confirm which
# citation is intended.
HEADER_CONTINUATION_RE = re.compile(b"%s[ \t]" % LINE_BREAK, re.MULTILINE)
# Number of bytes at the end of the buffer that are searched again on the
# next attempt. This must be long enough to contain any line breaks plus
# any additional boundary markers (--) such that they will be found in a
# subsequent search.
SEARCH_EXTRA_LENGTH = 8
    75 
    76 
    77class MultipartDecoder: 
    78    """Decodes a multipart message as bytes into Python events. 
    79 
    80    The part data is returned as available to allow the caller to save 
    81    the data from memory to disk, if desired. 
    82    """ 
    83 
    84    def __init__( 
    85        self, 
    86        boundary: bytes, 
    87        max_form_memory_size: int | None = None, 
    88        *, 
    89        max_parts: int | None = None, 
    90    ) -> None: 
    91        self.buffer = bytearray() 
    92        self.complete = False 
    93        self.max_form_memory_size = max_form_memory_size 
    94        self.max_parts = max_parts 
    95        self.state = State.PREAMBLE 
    96        self.boundary = boundary 
    97 
    98        # Note in the below \h i.e. horizontal whitespace is used 
    99        # as [^\S\n\r] as \h isn't supported in python. 
    100 
    101        # The preamble must end with a boundary where the boundary is 
    102        # prefixed by a line break, RFC2046. Except that many 
    103        # implementations including Werkzeug's tests omit the line 
    104        # break prefix. In addition the first boundary could be the 
    105        # epilogue boundary (for empty form-data) hence the matching 
    106        # group to understand if it is an epilogue boundary. 
    107        self.preamble_re = re.compile( 
    108            rb"%s?--%s(--[^\S\n\r]*%s?|[^\S\n\r]*%s)" 
    109            % (LINE_BREAK, re.escape(boundary), LINE_BREAK, LINE_BREAK), 
    110            re.MULTILINE, 
    111        ) 
    112        # A boundary must include a line break prefix and suffix, and 
    113        # may include trailing whitespace. In addition the boundary 
    114        # could be the epilogue boundary hence the matching group to 
    115        # understand if it is an epilogue boundary. 
    116        self.boundary_re = re.compile( 
    117            rb"%s--%s(--[^\S\n\r]*%s?|[^\S\n\r]*%s)" 
    118            % (LINE_BREAK, re.escape(boundary), LINE_BREAK, LINE_BREAK), 
    119            re.MULTILINE, 
    120        ) 
    121        self._search_position = 0 
    122        self._parts_decoded = 0 
    123 
    124    def last_newline(self, data: bytes) -> int: 
    125        try: 
    126            last_nl = data.rindex(b"\n") 
    127        except ValueError: 
    128            last_nl = len(data) 
    129        try: 
    130            last_cr = data.rindex(b"\r") 
    131        except ValueError: 
    132            last_cr = len(data) 
    133 
    134        return min(last_nl, last_cr) 
    135 
    136    def receive_data(self, data: bytes | None) -> None: 
    137        if data is None: 
    138            self.complete = True 
    139        elif ( 
    140            self.max_form_memory_size is not None 
    141            and len(self.buffer) + len(data) > self.max_form_memory_size 
    142        ): 
    143            raise RequestEntityTooLarge() 
    144        else: 
    145            self.buffer.extend(data) 
    146 
    147    def next_event(self) -> Event: 
    148        event: Event = NEED_DATA 
    149 
    150        if self.state == State.PREAMBLE: 
    151            match = self.preamble_re.search(self.buffer, self._search_position) 
    152            if match is not None: 
    153                if match.group(1).startswith(b"--"): 
    154                    self.state = State.EPILOGUE 
    155                else: 
    156                    self.state = State.PART 
    157                data = bytes(self.buffer[: match.start()]) 
    158                del self.buffer[: match.end()] 
    159                event = Preamble(data=data) 
    160                self._search_position = 0 
    161            else: 
    162                # Update the search start position to be equal to the 
    163                # current buffer length (already searched) minus a 
    164                # safe buffer for part of the search target. 
    165                self._search_position = max( 
    166                    0, len(self.buffer) - len(self.boundary) - SEARCH_EXTRA_LENGTH 
    167                ) 
    168 
    169        elif self.state == State.PART: 
    170            match = BLANK_LINE_RE.search(self.buffer, self._search_position) 
    171            if match is not None: 
    172                headers = self._parse_headers(self.buffer[: match.start()]) 
    173                # The final header ends with a single CRLF, however a 
    174                # blank line indicates the start of the 
    175                # body. Therefore the end is after the first CRLF. 
    176                headers_end = (match.start() + match.end()) // 2 
    177                del self.buffer[:headers_end] 
    178 
    179                if "content-disposition" not in headers: 
    180                    raise ValueError("Missing Content-Disposition header") 
    181 
    182                disposition, extra = parse_options_header( 
    183                    headers["content-disposition"] 
    184                ) 
    185                name = t.cast(str, extra.get("name")) 
    186                filename = extra.get("filename") 
    187                if filename is not None: 
    188                    event = File( 
    189                        filename=filename, 
    190                        headers=headers, 
    191                        name=name, 
    192                    ) 
    193                else: 
    194                    event = Field( 
    195                        headers=headers, 
    196                        name=name, 
    197                    ) 
    198                self.state = State.DATA_START 
    199                self._search_position = 0 
    200                self._parts_decoded += 1 
    201 
    202                if self.max_parts is not None and self._parts_decoded > self.max_parts: 
    203                    raise RequestEntityTooLarge() 
    204            else: 
    205                # Update the search start position to be equal to the 
    206                # current buffer length (already searched) minus a 
    207                # safe buffer for part of the search target. 
    208                self._search_position = max(0, len(self.buffer) - SEARCH_EXTRA_LENGTH) 
    209 
    210        elif self.state == State.DATA_START: 
    211            data, del_index, more_data = self._parse_data(self.buffer, start=True) 
    212            del self.buffer[:del_index] 
    213            event = Data(data=data, more_data=more_data) 
    214            if more_data: 
    215                self.state = State.DATA 
    216 
    217        elif self.state == State.DATA: 
    218            data, del_index, more_data = self._parse_data(self.buffer, start=False) 
    219            del self.buffer[:del_index] 
    220            if data or not more_data: 
    221                event = Data(data=data, more_data=more_data) 
    222 
    223        elif self.state == State.EPILOGUE and self.complete: 
    224            event = Epilogue(data=bytes(self.buffer)) 
    225            del self.buffer[:] 
    226            self.state = State.COMPLETE 
    227 
    228        if self.complete and isinstance(event, NeedData): 
    229            raise ValueError(f"Invalid form-data cannot parse beyond {self.state}") 
    230 
    231        return event 
    232 
    233    def _parse_headers(self, data: bytes) -> Headers: 
    234        headers: list[tuple[str, str]] = [] 
    235        # Merge the continued headers into one line 
    236        data = HEADER_CONTINUATION_RE.sub(b" ", data) 
    237        # Now there is one header per line 
    238        for line in data.splitlines(): 
    239            line = line.strip() 
    240 
    241            if line != b"": 
    242                name, _, value = line.decode().partition(":") 
    243                headers.append((name.strip(), value.strip())) 
    244        return Headers(headers) 
    245 
    246    def _parse_data(self, data: bytes, *, start: bool) -> tuple[bytes, int, bool]: 
    247        # Body parts must start with CRLF (or CR or LF) 
    248        if start: 
    249            match = LINE_BREAK_RE.match(data) 
    250            data_start = t.cast(t.Match[bytes], match).end() 
    251        else: 
    252            data_start = 0 
    253 
    254        boundary = b"--" + self.boundary 
    255 
    256        if self.buffer.find(boundary) == -1: 
    257            # No complete boundary in the buffer, but there may be 
    258            # a partial boundary at the end. As the boundary 
    259            # starts with either a nl or cr find the earliest and 
    260            # return up to that as data. 
    261            data_end = del_index = self.last_newline(data[data_start:]) + data_start 
    262            # If amount of data after last newline is far from 
    263            # possible length of partial boundary, we should 
    264            # assume that there is no partial boundary in the buffer 
    265            # and return all pending data. 
    266            if (len(data) - data_end) > len(b"\n" + boundary): 
    267                data_end = del_index = len(data) 
    268            more_data = True 
    269        else: 
    270            match = self.boundary_re.search(data) 
    271            if match is not None: 
    272                if match.group(1).startswith(b"--"): 
    273                    self.state = State.EPILOGUE 
    274                else: 
    275                    self.state = State.PART 
    276                data_end = match.start() 
    277                del_index = match.end() 
    278            else: 
    279                data_end = del_index = self.last_newline(data[data_start:]) + data_start 
    280            more_data = match is None 
    281 
    282        return bytes(data[data_start:data_end]), del_index, more_data 
    283 
    284 
    285class MultipartEncoder: 
    286    def __init__(self, boundary: bytes) -> None: 
    287        self.boundary = boundary 
    288        self.state = State.PREAMBLE 
    289 
    290    def send_event(self, event: Event) -> bytes: 
    291        if isinstance(event, Preamble) and self.state == State.PREAMBLE: 
    292            self.state = State.PART 
    293            return event.data 
    294        elif isinstance(event, (Field, File)) and self.state in { 
    295            State.PREAMBLE, 
    296            State.PART, 
    297            State.DATA, 
    298        }: 
    299            data = b"\r\n--" + self.boundary + b"\r\n" 
    300            data += b'Content-Disposition: form-data; name="%s"' % event.name.encode() 
    301            if isinstance(event, File): 
    302                data += b'; filename="%s"' % event.filename.encode() 
    303            data += b"\r\n" 
    304            for name, value in t.cast(Field, event).headers: 
    305                if name.lower() != "content-disposition": 
    306                    data += f"{name}: {value}\r\n".encode() 
    307            self.state = State.DATA_START 
    308            return data 
    309        elif isinstance(event, Data) and self.state == State.DATA_START: 
    310            self.state = State.DATA 
    311            if len(event.data) > 0: 
    312                return b"\r\n" + event.data 
    313            else: 
    314                return event.data 
    315        elif isinstance(event, Data) and self.state == State.DATA: 
    316            return event.data 
    317        elif isinstance(event, Epilogue): 
    318            self.state = State.COMPLETE 
    319            return b"\r\n--" + self.boundary + b"--\r\n" + event.data 
    320        else: 
    321            raise ValueError(f"Cannot generate {event} in state: {self.state}")