1""" 
    2This module is for codecs only. 
    3 
    4While the codec implementation can contain details of the PDF specification, 
    5the module should not do any PDF parsing. 
    6""" 
    7 
    8import io 
    9from abc import ABC, abstractmethod 
    10 
    11from pypdf._utils import logger_warning 
    12from pypdf.errors import LimitReachedError 
    13 
    14 
    15class Codec(ABC): 
    16    """Abstract base class for all codecs.""" 
    17 
    18    @abstractmethod 
    19    def encode(self, data: bytes) -> bytes: 
    20        """ 
    21        Encode the input data. 
    22 
    23        Args: 
    24            data: Data to encode. 
    25 
    26        Returns: 
    27            Encoded data. 
    28 
    29        """ 
    30 
    31    @abstractmethod 
    32    def decode(self, data: bytes) -> bytes: 
    33        """ 
    34        Decode the input data. 
    35 
    36        Args: 
    37            data: Data to decode. 
    38 
    39        Returns: 
    40            Decoded data. 
    41 
    42        """ 
    43 
    44 
    45class LzwCodec(Codec): 
    46    """Lempel-Ziv-Welch (LZW) adaptive compression codec.""" 
    47 
    48    CLEAR_TABLE_MARKER = 256  # Special code to indicate table reset 
    49    EOD_MARKER = 257  # End-of-data marker 
    50    INITIAL_BITS_PER_CODE = 9  # Initial code bit width 
    51    MAX_BITS_PER_CODE = 12  # Maximum code bit width 
    52 
    53    def __init__(self, max_output_length: int = 1_000_000_000) -> None: 
    54        self.max_output_length = max_output_length 
    55 
    56    def _initialize_encoding_table(self) -> None: 
    57        """Initialize the encoding table and state to initial conditions.""" 
    58        self.encoding_table: dict[bytes, int] = {bytes([i]): i for i in range(256)} 
    59        self.next_code = self.EOD_MARKER + 1 
    60        self.bits_per_code = self.INITIAL_BITS_PER_CODE 
    61        self.max_code_value = (1 << self.bits_per_code) - 1 
    62 
    63    def _increase_next_code(self) -> None: 
    64        """Update bits_per_code and max_code_value if necessary.""" 
    65        self.next_code += 1 
    66        if ( 
    67            self.next_code > self.max_code_value 
    68            and self.bits_per_code < self.MAX_BITS_PER_CODE 
    69        ): 
    70            self.bits_per_code += 1 
    71            self.max_code_value = (1 << self.bits_per_code) - 1 
    72 
    73    def encode(self, data: bytes) -> bytes: 
    74        """ 
    75        Encode data using the LZW compression algorithm. 
    76 
    77        Taken from PDF 1.7 specs, "7.4.4.2 Details of LZW Encoding". 
    78        """ 
    79        result_codes: list[int] = [] 
    80 
    81        # The encoder shall begin by issuing a clear-table code 
    82        result_codes.append(self.CLEAR_TABLE_MARKER) 
    83        self._initialize_encoding_table() 
    84 
    85        current_sequence = b"" 
    86        for byte in data: 
    87            next_sequence = current_sequence + bytes([byte]) 
    88 
    89            if next_sequence in self.encoding_table: 
    90                # Extend current sequence if already in the table 
    91                current_sequence = next_sequence 
    92            else: 
    93                # Output code for the current sequence 
    94                result_codes.append(self.encoding_table[current_sequence]) 
    95 
    96                # Add the new sequence to the table if there's room 
    97                if self.next_code <= (1 << self.MAX_BITS_PER_CODE) - 1: 
    98                    self.encoding_table[next_sequence] = self.next_code 
    99                    self._increase_next_code() 
    100                else: 
    101                    # If the table is full, emit a clear-table command 
    102                    result_codes.append(self.CLEAR_TABLE_MARKER) 
    103                    self._initialize_encoding_table() 
    104 
    105                # Start new sequence 
    106                current_sequence = bytes([byte]) 
    107 
    108        # Ensure everything actually is encoded 
    109        if current_sequence: 
    110            result_codes.append(self.encoding_table[current_sequence]) 
    111        result_codes.append(self.EOD_MARKER) 
    112 
    113        return self._pack_codes_into_bytes(result_codes) 
    114 
    115    def _pack_codes_into_bytes(self, codes: list[int]) -> bytes: 
    116        """ 
    117        Convert the list of result codes into a continuous byte stream, with codes packed as per the code bit-width. 
    118        The bit-width starts at 9 bits and expands as needed. 
    119        """ 
    120        self._initialize_encoding_table() 
    121        buffer = 0 
    122        bits_in_buffer = 0 
    123        output = bytearray() 
    124 
    125        for code in codes: 
    126            buffer = (buffer << self.bits_per_code) | code 
    127            bits_in_buffer += self.bits_per_code 
    128 
    129            # Codes shall be packed into a continuous bit stream, high-order bit 
    130            # first. This stream shall then be divided into bytes, high-order bit 
    131            # first. 
    132            while bits_in_buffer >= 8: 
    133                bits_in_buffer -= 8 
    134                output.append((buffer >> bits_in_buffer) & 0xFF) 
    135 
    136            if code == self.CLEAR_TABLE_MARKER: 
    137                self._initialize_encoding_table() 
    138            elif code == self.EOD_MARKER: 
    139                continue 
    140            else: 
    141                self._increase_next_code() 
    142 
    143        # Flush any remaining bits in the buffer 
    144        if bits_in_buffer > 0: 
    145            output.append((buffer << (8 - bits_in_buffer)) & 0xFF) 
    146 
    147        return bytes(output) 
    148 
    149    def _initialize_decoding_table(self) -> None: 
    150        self.max_code_value = (1 << self.MAX_BITS_PER_CODE) - 1 
    151        self.decoding_table = [bytes([i]) for i in range(self.CLEAR_TABLE_MARKER)] + [ 
    152            b"" 
    153        ] * (self.max_code_value - self.CLEAR_TABLE_MARKER + 1) 
    154        self._table_index = self.EOD_MARKER + 1 
    155        self._bits_to_get = 9 
    156 
    157    def _next_code_decode(self, data: bytes) -> int: 
    158        self._next_data: int 
    159        try: 
    160            while self._next_bits < self._bits_to_get: 
    161                self._next_data = (self._next_data << 8) | ( 
    162                    data[self._byte_pointer] 
    163                ) 
    164                self._byte_pointer += 1 
    165                self._next_bits += 8 
    166 
    167            code = ( 
    168                self._next_data >> (self._next_bits - self._bits_to_get) 
    169            ) & self._and_table[self._bits_to_get - 9] 
    170            self._next_bits -= self._bits_to_get 
    171 
    172            # Reduce data to get rid of the overhead, 
    173            # which increases performance on large streams significantly. 
    174            self._next_data = self._next_data & 0xFFFFF 
    175 
    176            return code 
    177        except IndexError: 
    178            return self.EOD_MARKER 
    179 
    180    # The following method has been converted to Python from PDFsharp: 
    181    # https://github.com/empira/PDFsharp/blob/5fbf6ed14740bc4e16786816882d32e43af3ff5d/src/foundation/src/PDFsharp/src/PdfSharp/Pdf.Filters/LzwDecode.cs 
    182    # 
    183    # Original license: 
    184    # 
    185    # ------------------------------------------------------------------------- 
    186    # Copyright (c) 2001-2024 empira Software GmbH, Troisdorf (Cologne Area), 
    187    # Germany 
    188    # 
    189    # http://docs.pdfsharp.net 
    190    # 
    191    # MIT License 
    192    # 
    193    # Permission is hereby granted, free of charge, to any person obtaining a 
    194    # copy of this software and associated documentation files (the "Software"), 
    195    # to deal in the Software without restriction, including without limitation 
    196    # the rights to use, copy, modify, merge, publish, distribute, sublicense, 
    197    # and/or sell copies of the Software, and to permit persons to whom the 
    198    # Software is furnished to do so, subject to the following conditions: 
    199    # 
    200    # The above copyright notice and this permission notice shall be included 
    201    # in all copies or substantial portions of the Software. 
    202    # 
    203    # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 
    204    # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 
    205    # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL 
    206    # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 
    207    # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 
    208    # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 
    209    # DEALINGS IN THE SOFTWARE. 
    210    # -------------------------------------------------------------------------- 
    211    def decode(self, data: bytes) -> bytes: 
    212        """ 
    213        The following code was converted to Python from the following code: 
    214        https://github.com/empira/PDFsharp/blob/master/src/foundation/src/PDFsharp/src/PdfSharp/Pdf.Filters/LzwDecode.cs 
    215        """ 
    216        self._and_table = [511, 1023, 2047, 4095] 
    217        self._table_index = 0 
    218        self._bits_to_get = 9 
    219        self._byte_pointer = 0 
    220        self._next_data = 0 
    221        self._next_bits = 0 
    222 
    223        output_stream = io.BytesIO() 
    224        output_length = 0 
    225 
    226        self._initialize_decoding_table() 
    227        self._byte_pointer = 0 
    228        self._next_data = 0 
    229        self._next_bits = 0 
    230        old_code = self.CLEAR_TABLE_MARKER 
    231 
    232        while True: 
    233            code = self._next_code_decode(data) 
    234            if code == self.EOD_MARKER: 
    235                break 
    236 
    237            if code == self.CLEAR_TABLE_MARKER: 
    238                self._initialize_decoding_table() 
    239                code = self._next_code_decode(data) 
    240                if code == self.EOD_MARKER: 
    241                    break 
    242                output_stream.write(decoded := self.decoding_table[code]) 
    243                old_code = code 
    244            elif code < self._table_index: 
    245                decoded = self.decoding_table[code] 
    246                output_stream.write(decoded) 
    247                if old_code != self.CLEAR_TABLE_MARKER: 
    248                    self._add_entry_decode(self.decoding_table[old_code], decoded[0]) 
    249                old_code = code 
    250            else: 
    251                # The code is not in the table and not one of the special codes 
    252                decoded = ( 
    253                    self.decoding_table[old_code] + self.decoding_table[old_code][:1] 
    254                ) 
    255                output_stream.write(decoded) 
    256                self._add_entry_decode(self.decoding_table[old_code], decoded[0]) 
    257                old_code = code 
    258 
    259            output_length += len(decoded) 
    260            if output_length > self.max_output_length: 
    261                raise LimitReachedError( 
    262                    f"Limit reached while decompressing: {output_length} > {self.max_output_length}" 
    263                ) 
    264 
    265        return output_stream.getvalue() 
    266 
    267    def _add_entry_decode(self, old_string: bytes, new_char: int) -> None: 
    268        new_string = old_string + bytes([new_char]) 
    269        if self._table_index > self.max_code_value: 
    270            logger_warning("Ignoring too large LZW table index.", __name__) 
    271            return 
    272        self.decoding_table[self._table_index] = new_string 
    273        self._table_index += 1 
    274 
    275        # Update the number of bits to get based on the table index 
    276        if self._table_index == 511: 
    277            self._bits_to_get = 10 
    278        elif self._table_index == 1023: 
    279            self._bits_to_get = 11 
    280        elif self._table_index == 2047: 
    281            self._bits_to_get = 12