1"""
2This module is for codecs only.
3
4While the codec implementation can contain details of the PDF specification,
5the module should not do any PDF parsing.
6"""
7
8import io
9from abc import ABC, abstractmethod
10
11from pypdf._utils import logger_warning
12from pypdf.errors import LimitReachedError
13
14
class Codec(ABC):
    """Abstract interface shared by all codecs."""

    @abstractmethod
    def encode(self, data: bytes) -> bytes:
        """
        Turn raw data into its encoded representation.

        Args:
            data: The raw bytes to encode.

        Returns:
            The encoded bytes.

        """

    @abstractmethod
    def decode(self, data: bytes) -> bytes:
        """
        Turn encoded data back into its raw representation.

        Args:
            data: The encoded bytes to decode.

        Returns:
            The decoded bytes.

        """
44
class LzwCodec(Codec):
    """
    Lempel-Ziv-Welch (LZW) adaptive compression codec.

    Codes 0-255 stand for the single byte of the same value, followed by the
    two special codes below; learned table entries start right after
    ``EOD_MARKER``. Codes are ``INITIAL_BITS_PER_CODE`` bits wide at first and
    grow up to ``MAX_BITS_PER_CODE`` bits as the table fills.

    The encoder and the decoder keep their working state in instance
    attributes, which are re-initialized by every ``encode``/``decode`` call.
    """

    CLEAR_TABLE_MARKER = 256  # Special code to indicate table reset
    EOD_MARKER = 257  # End-of-data marker
    INITIAL_BITS_PER_CODE = 9  # Initial code bit width
    MAX_BITS_PER_CODE = 12  # Maximum code bit width

    def __init__(self, max_output_length: int = 1_000_000_000) -> None:
        """
        Create a codec.

        Args:
            max_output_length: Upper bound for the number of bytes ``decode``
                may produce; exceeding it raises ``LimitReachedError``.

        """
        self.max_output_length = max_output_length

    def _initialize_encoding_table(self) -> None:
        """Initialize the encoding table and state to initial conditions."""
        self.encoding_table: dict[bytes, int] = {bytes([i]): i for i in range(256)}
        self.next_code = self.EOD_MARKER + 1
        self.bits_per_code = self.INITIAL_BITS_PER_CODE
        self.max_code_value = (1 << self.bits_per_code) - 1

    def _increase_next_code(self) -> None:
        """Advance next_code; widen bits_per_code and max_code_value if necessary."""
        self.next_code += 1
        if (
            self.next_code > self.max_code_value
            and self.bits_per_code < self.MAX_BITS_PER_CODE
        ):
            self.bits_per_code += 1
            self.max_code_value = (1 << self.bits_per_code) - 1

    def encode(self, data: bytes) -> bytes:
        """
        Encode data using the LZW compression algorithm.

        Taken from PDF 1.7 specs, "7.4.4.2 Details of LZW Encoding".

        Args:
            data: Data to encode.

        Returns:
            The packed code stream. It always starts with a clear-table code
            and ends with an end-of-data code, even for empty input.

        """
        # The encoder shall begin by issuing a clear-table code
        result_codes: list[int] = [self.CLEAR_TABLE_MARKER]
        self._initialize_encoding_table()
        largest_code = (1 << self.MAX_BITS_PER_CODE) - 1

        current_sequence = b""
        for byte in data:
            next_sequence = current_sequence + bytes([byte])

            if next_sequence in self.encoding_table:
                # Extend current sequence if already in the table
                current_sequence = next_sequence
                continue

            # Output code for the current sequence
            result_codes.append(self.encoding_table[current_sequence])

            # Add the new sequence to the table if there's room
            if self.next_code <= largest_code:
                self.encoding_table[next_sequence] = self.next_code
                self._increase_next_code()
            else:
                # If the table is full, emit a clear-table command
                result_codes.append(self.CLEAR_TABLE_MARKER)
                self._initialize_encoding_table()

            # Start new sequence
            current_sequence = bytes([byte])

        # Ensure everything actually is encoded
        if current_sequence:
            result_codes.append(self.encoding_table[current_sequence])
        result_codes.append(self.EOD_MARKER)

        return self._pack_codes_into_bytes(result_codes)

    def _pack_codes_into_bytes(self, codes: list[int]) -> bytes:
        """
        Convert the list of result codes into a continuous byte stream, with codes packed as per the code bit-width.
        The bit-width starts at 9 bits and expands as needed.

        The encoder state is replayed while packing so that the width used
        for each code matches the table growth.

        Args:
            codes: Codes as produced by ``encode``, including the clear-table
                and end-of-data markers.

        Returns:
            The packed bytes; a trailing partial byte is padded with zero bits.

        """
        self._initialize_encoding_table()
        buffer = 0
        bits_in_buffer = 0
        output = bytearray()

        for code in codes:
            buffer = (buffer << self.bits_per_code) | code
            bits_in_buffer += self.bits_per_code

            # Codes shall be packed into a continuous bit stream, high-order bit
            # first. This stream shall then be divided into bytes, high-order bit
            # first.
            while bits_in_buffer >= 8:
                bits_in_buffer -= 8
                output.append((buffer >> bits_in_buffer) & 0xFF)

            # Keep only the bits not yet written. Without this the integer
            # grows with the whole stream and every shift becomes linear in
            # the output size, i.e. quadratic packing time overall.
            buffer &= (1 << bits_in_buffer) - 1

            if code == self.CLEAR_TABLE_MARKER:
                self._initialize_encoding_table()
            elif code != self.EOD_MARKER:
                self._increase_next_code()

        # Flush any remaining bits in the buffer
        if bits_in_buffer > 0:
            output.append((buffer << (8 - bits_in_buffer)) & 0xFF)

        return bytes(output)

    def _initialize_decoding_table(self) -> None:
        """Reset the decoding table, the next free table index and the code width."""
        self.max_code_value = (1 << self.MAX_BITS_PER_CODE) - 1
        # Single-byte entries for 0-255; every other slot (including the two
        # marker codes) starts out as an empty placeholder.
        self.decoding_table = [bytes([i]) for i in range(self.CLEAR_TABLE_MARKER)] + [
            b""
        ] * (self.max_code_value - self.CLEAR_TABLE_MARKER + 1)
        self._table_index = self.EOD_MARKER + 1
        self._bits_to_get = self.INITIAL_BITS_PER_CODE

    def _next_code_decode(self, data: bytes) -> int:
        """
        Read the next code of the current bit width from ``data``.

        Args:
            data: The complete encoded stream; the read position is kept in
                ``self._byte_pointer``.

        Returns:
            The next code, or ``EOD_MARKER`` if the input is exhausted before
            a full code could be read.

        """
        try:
            while self._next_bits < self._bits_to_get:
                self._next_data = (self._next_data << 8) | data[self._byte_pointer]
                self._byte_pointer += 1
                self._next_bits += 8
        except IndexError:
            # Ran out of input: treat it like an end-of-data marker.
            return self.EOD_MARKER

        self._next_bits -= self._bits_to_get
        code = (self._next_data >> self._next_bits) & ((1 << self._bits_to_get) - 1)

        # Reduce data to get rid of the overhead,
        # which increases performance on large streams significantly.
        self._next_data &= (1 << self._next_bits) - 1

        return code

    # The following method has been converted to Python from PDFsharp:
    # https://github.com/empira/PDFsharp/blob/5fbf6ed14740bc4e16786816882d32e43af3ff5d/src/foundation/src/PDFsharp/src/PdfSharp/Pdf.Filters/LzwDecode.cs
    #
    # Original license:
    #
    # -------------------------------------------------------------------------
    # Copyright (c) 2001-2024 empira Software GmbH, Troisdorf (Cologne Area),
    # Germany
    #
    # http://docs.pdfsharp.net
    #
    # MIT License
    #
    # Permission is hereby granted, free of charge, to any person obtaining a
    # copy of this software and associated documentation files (the "Software"),
    # to deal in the Software without restriction, including without limitation
    # the rights to use, copy, modify, merge, publish, distribute, sublicense,
    # and/or sell copies of the Software, and to permit persons to whom the
    # Software is furnished to do so, subject to the following conditions:
    #
    # The above copyright notice and this permission notice shall be included
    # in all copies or substantial portions of the Software.
    #
    # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
    # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
    # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
    # DEALINGS IN THE SOFTWARE.
    # --------------------------------------------------------------------------
    def decode(self, data: bytes) -> bytes:
        """
        The following code was converted to Python from the following code:
        https://github.com/empira/PDFsharp/blob/master/src/foundation/src/PDFsharp/src/PdfSharp/Pdf.Filters/LzwDecode.cs

        Args:
            data: LZW-encoded data. Input that ends without an end-of-data
                code is decoded up to the last complete code.

        Returns:
            The decoded data.

        Raises:
            LimitReachedError: If the decoded output grows beyond
                ``max_output_length``.

        """
        self._initialize_decoding_table()
        self._byte_pointer = 0
        self._next_data = 0
        self._next_bits = 0

        output_stream = io.BytesIO()
        output_length = 0
        # Marks "no previous code yet", so no table entry is derived from it.
        old_code = self.CLEAR_TABLE_MARKER

        while True:
            code = self._next_code_decode(data)
            if code == self.EOD_MARKER:
                break

            if code == self.CLEAR_TABLE_MARKER:
                self._initialize_decoding_table()
                code = self._next_code_decode(data)
                if code == self.EOD_MARKER:
                    break
                decoded = self.decoding_table[code]
                output_stream.write(decoded)
            elif code < self._table_index:
                decoded = self.decoding_table[code]
                output_stream.write(decoded)
                if old_code != self.CLEAR_TABLE_MARKER:
                    self._add_entry_decode(self.decoding_table[old_code], decoded[0])
            else:
                # The code is not in the table and not one of the special codes
                previous = self.decoding_table[old_code]
                decoded = previous + previous[:1]
                output_stream.write(decoded)
                # NOTE(review): when the entry for ``old_code`` is empty (e.g. a
                # stream that does not start with a clear-table code and whose
                # first code is not in the table yet), ``decoded`` is empty and
                # ``decoded[0]`` raises IndexError.
                self._add_entry_decode(previous, decoded[0])
            old_code = code

            output_length += len(decoded)
            if output_length > self.max_output_length:
                raise LimitReachedError(
                    f"Limit reached while decompressing: {output_length} > {self.max_output_length}"
                )

        return output_stream.getvalue()

    def _add_entry_decode(self, old_string: bytes, new_char: int) -> None:
        """
        Append ``old_string + new_char`` to the decoding table.

        A full table is left untouched and only a warning is logged.

        Args:
            old_string: The string decoded for the previous code.
            new_char: Byte value to append to ``old_string``.

        """
        if self._table_index > self.max_code_value:
            logger_warning("Ignoring too large LZW table index.", __name__)
            return
        self.decoding_table[self._table_index] = old_string + bytes([new_char])
        self._table_index += 1

        # Update the number of bits to get based on the table index: the width
        # grows once the next free index reaches the largest value of the
        # current width (511 -> 10 bits, 1023 -> 11 bits, 2047 -> 12 bits).
        if (
            self._bits_to_get < self.MAX_BITS_PER_CODE
            and self._table_index == (1 << self._bits_to_get) - 1
        ):
            self._bits_to_get += 1