1import asyncio
2import sys
3import zlib
4from abc import ABC, abstractmethod
5from concurrent.futures import Executor
6from typing import Any, Final, Protocol, TypedDict, cast
7
8if sys.version_info >= (3, 12):
9 from collections.abc import Buffer
10else:
11 from typing import Union
12
13 Buffer = Union[bytes, bytearray, "memoryview[int]", "memoryview[bytes]"]
14
15try:
16 try:
17 import brotlicffi as brotli
18 except ImportError:
19 import brotli
20
21 HAS_BROTLI = True
22except ImportError: # pragma: no cover
23 HAS_BROTLI = False
24
25try:
26 if sys.version_info >= (3, 14):
27 from compression.zstd import ZstdDecompressor # noqa: I900
28 else: # TODO(PY314): Remove mentions of backports.zstd across codebase
29 from backports.zstd import ZstdDecompressor
30
31 HAS_ZSTD = True
32except ImportError:
33 HAS_ZSTD = False
34
35
# Default for the ``max_sync_chunk_size`` parameter used throughout this
# module: inputs whose ``len()`` exceeds it are processed in an executor by
# the async ``compress()``/``decompress()`` helpers rather than inline.
MAX_SYNC_CHUNK_SIZE = 4096

# Unlimited decompression constants - different libraries use different conventions
ZLIB_MAX_LENGTH_UNLIMITED = 0  # zlib uses 0 to mean unlimited
ZSTD_MAX_LENGTH_UNLIMITED = -1  # zstd uses -1 to mean unlimited
41
42
class ZLibCompressObjProtocol(Protocol):
    """Structural type of a streaming compressor object.

    Describes what a backend's ``compressobj()`` is expected to return.
    """

    def compress(self, data: Buffer) -> bytes:
        """Accept ``data`` and return compressed bytes."""

    def flush(self, mode: int = ..., /) -> bytes:
        """Flush the compressor using ``mode`` and return the resulting bytes."""
46
47
class ZLibDecompressObjProtocol(Protocol):
    """Structural type of a streaming decompressor object.

    Describes what a backend's ``decompressobj()`` is expected to return.
    """

    def decompress(self, data: Buffer, max_length: int = ...) -> bytes:
        """Accept ``data`` and return decompressed bytes, bounded by ``max_length``."""

    def flush(self, length: int = ..., /) -> bytes:
        """Return whatever output the decompressor still holds."""

    @property
    def eof(self) -> bool:
        """Whether the end of the compressed stream has been reached."""

    @property
    def unconsumed_tail(self) -> bytes:
        """Input that was held back and must be fed again to continue."""

    @property
    def unused_data(self) -> bytes:
        """Input found after the end of the compressed stream."""
60
61
class ZLibBackendProtocol(Protocol):
    """Structural type of a module usable as a zlib backend.

    Lists the constants and callables this file reads from the backend;
    the standard library ``zlib`` module is the default implementation.
    """

    # Constants read through ZLibBackendWrapper.
    MAX_WBITS: int
    Z_FULL_FLUSH: int
    Z_SYNC_FLUSH: int
    Z_BEST_SPEED: int
    Z_FINISH: int

    def compressobj(
        self,
        level: int = ...,
        method: int = ...,
        wbits: int = ...,
        memLevel: int = ...,
        strategy: int = ...,
        zdict: Buffer | None = ...,
    ) -> ZLibCompressObjProtocol:
        """Create a streaming compressor object."""

    def decompressobj(
        self, wbits: int = ..., zdict: Buffer = ...
    ) -> ZLibDecompressObjProtocol:
        """Create a streaming decompressor object."""

    def compress(
        self, data: Buffer, /, level: int = ..., wbits: int = ...
    ) -> bytes:
        """Compress ``data`` in a single call."""

    def decompress(
        self, data: Buffer, /, wbits: int = ..., bufsize: int = ...
    ) -> bytes:
        """Decompress ``data`` in a single call."""
88
89
class CompressObjArgs(TypedDict, total=False):
    """Keyword arguments assembled for a backend ``compressobj()`` call.

    Every key is optional (``total=False``), so only the options that were
    explicitly chosen need to be present.
    """

    level: int
    strategy: int
    wbits: int
94
95
class ZLibBackendWrapper:
    """Proxy around a zlib-compatible backend module.

    The wrapped module is stored in ``_zlib_backend`` and is looked up on
    every access, so replacing that attribute changes which module all
    subsequent calls are forwarded to.
    """

    def __init__(self, _zlib_backend: ZLibBackendProtocol):
        self._zlib_backend: ZLibBackendProtocol = _zlib_backend

    @property
    def name(self) -> str:
        """Return the backend's ``__name__``, or ``"undefined"`` if it has none."""
        return getattr(self._zlib_backend, "__name__", "undefined")

    @property
    def MAX_WBITS(self) -> int:
        """Return the backend's ``MAX_WBITS`` constant."""
        return self._zlib_backend.MAX_WBITS

    @property
    def Z_FULL_FLUSH(self) -> int:
        """Return the backend's ``Z_FULL_FLUSH`` constant."""
        return self._zlib_backend.Z_FULL_FLUSH

    @property
    def Z_SYNC_FLUSH(self) -> int:
        """Return the backend's ``Z_SYNC_FLUSH`` constant."""
        return self._zlib_backend.Z_SYNC_FLUSH

    @property
    def Z_BEST_SPEED(self) -> int:
        """Return the backend's ``Z_BEST_SPEED`` constant."""
        return self._zlib_backend.Z_BEST_SPEED

    @property
    def Z_FINISH(self) -> int:
        """Return the backend's ``Z_FINISH`` constant."""
        return self._zlib_backend.Z_FINISH

    def compressobj(self, *args: Any, **kwargs: Any) -> ZLibCompressObjProtocol:
        """Forward to the backend's ``compressobj()``."""
        return self._zlib_backend.compressobj(*args, **kwargs)

    def decompressobj(self, *args: Any, **kwargs: Any) -> ZLibDecompressObjProtocol:
        """Forward to the backend's ``decompressobj()``."""
        return self._zlib_backend.decompressobj(*args, **kwargs)

    def compress(self, data: Buffer, *args: Any, **kwargs: Any) -> bytes:
        """Forward to the backend's one-shot ``compress()``."""
        return self._zlib_backend.compress(data, *args, **kwargs)

    def decompress(self, data: Buffer, *args: Any, **kwargs: Any) -> bytes:
        """Forward to the backend's one-shot ``decompress()``."""
        return self._zlib_backend.decompress(data, *args, **kwargs)

    # Everything not explicitly listed in the Protocol we just pass through
    def __getattr__(self, attrname: str) -> Any:
        """Look up ``attrname`` on the backend when normal lookup fails.

        Raises:
            AttributeError: if the backend lacks the attribute, or if the
                wrapper has no backend set yet.
        """
        if attrname == "_zlib_backend":
            # __getattr__ only runs when normal lookup failed, so reaching
            # here means the backend was never assigned (e.g. an instance
            # built without __init__ by copy/pickle). Reading
            # self._zlib_backend below would re-enter this method forever.
            raise AttributeError(attrname)
        return getattr(self._zlib_backend, attrname)
139
140
# Module-wide backend wrapper, initially proxying the standard library zlib.
# set_zlib_backend() replaces the module it wraps.
ZLibBackend: ZLibBackendWrapper = ZLibBackendWrapper(zlib)
142
143
def set_zlib_backend(new_zlib_backend: ZLibBackendProtocol) -> None:
    """Make ``new_zlib_backend`` the module wrapped by the global ``ZLibBackend``.

    The global wrapper object itself is kept; only the module it forwards
    to is replaced.
    """
    ZLibBackend._zlib_backend = new_zlib_backend
146
147
def encoding_to_mode(
    encoding: str | None = None,
    suppress_deflate_header: bool = False,
) -> int:
    """Map a content encoding to the ``wbits`` value given to the backend.

    Args:
        encoding: ``"gzip"`` selects ``16 + MAX_WBITS``; any other value
            (including ``None``) selects the non-gzip result.
        suppress_deflate_header: for non-gzip encodings, return the negated
            ``MAX_WBITS`` instead of the positive one.

    Returns:
        The ``wbits`` value derived from the current backend's ``MAX_WBITS``.
    """
    max_wbits = ZLibBackend.MAX_WBITS
    if encoding == "gzip":
        return 16 + max_wbits
    if suppress_deflate_header:
        return -max_wbits
    return max_wbits
156
157
class DecompressionBaseHandler(ABC):
    """Common interface for the decompressors in this module.

    Subclasses implement ``decompress_sync()`` and ``data_available``; the
    async ``decompress()`` wrapper decides whether to run the synchronous
    implementation inline or in an executor.
    """

    def __init__(
        self,
        executor: Executor | None = None,
        max_sync_chunk_size: int | None = MAX_SYNC_CHUNK_SIZE,
    ) -> None:
        """Base class for decompression handlers.

        Args:
            executor: executor passed to ``run_in_executor()`` for large
                inputs; ``None`` selects the loop's default executor.
            max_sync_chunk_size: inputs longer than this are decompressed in
                the executor; ``None`` disables offloading entirely.
        """
        self._executor = executor
        self._max_sync_chunk_size = max_sync_chunk_size

    @abstractmethod
    def decompress_sync(
        self, data: Buffer, max_length: int = ZLIB_MAX_LENGTH_UNLIMITED
    ) -> bytes:
        """Decompress the given data."""

    async def decompress(
        self, data: Buffer, max_length: int = ZLIB_MAX_LENGTH_UNLIMITED
    ) -> bytes:
        """Decompress the given data.

        Inputs longer than ``max_sync_chunk_size`` are handed to the
        executor; everything else is decompressed inline.
        """
        if (
            self._max_sync_chunk_size is not None
            and len(data) > self._max_sync_chunk_size
        ):
            # We are inside a coroutine, so a loop is guaranteed to be
            # running; get_running_loop() is the documented way to obtain it
            # and matches ZLibCompressor.compress().
            return await asyncio.get_running_loop().run_in_executor(
                self._executor, self.decompress_sync, data, max_length
            )
        return self.decompress_sync(data, max_length)

    @property
    @abstractmethod
    def data_available(self) -> bool:
        """Return True if more output is available by passing b""."""
191
192
class ZLibCompressor:
    """Streaming compressor built on the currently configured zlib backend."""

    def __init__(
        self,
        encoding: str | None = None,
        suppress_deflate_header: bool = False,
        level: int | None = None,
        wbits: int | None = None,
        strategy: int | None = None,
        executor: Executor | None = None,
        max_sync_chunk_size: int | None = MAX_SYNC_CHUNK_SIZE,
    ):
        """Create the underlying compressor object.

        An explicit ``wbits`` takes precedence; otherwise it is derived from
        ``encoding`` and ``suppress_deflate_header``. ``level`` and
        ``strategy`` are only forwarded to the backend when given.
        """
        self._executor = executor
        self._max_sync_chunk_size = max_sync_chunk_size
        if wbits is None:
            wbits = encoding_to_mode(encoding, suppress_deflate_header)
        self._mode = wbits
        # Wrap the module that is configured right now, so this instance
        # keeps using it even if the global backend is replaced later.
        self._zlib_backend: Final = ZLibBackendWrapper(ZLibBackend._zlib_backend)

        kwargs: CompressObjArgs = {"wbits": self._mode}
        if strategy is not None:
            kwargs["strategy"] = strategy
        if level is not None:
            kwargs["level"] = level
        self._compressor = self._zlib_backend.compressobj(**kwargs)

    def compress_sync(self, data: Buffer) -> bytes:
        """Compress ``data`` in the calling thread."""
        return self._compressor.compress(data)

    async def compress(self, data: Buffer) -> bytes:
        """Compress the data and return the compressed bytes.

        flush() must be called after the final compress() call.

        Data larger than max_sync_chunk_size is compressed in the executor;
        smaller data is compressed directly in the event loop.

        **WARNING: This method is NOT cancellation-safe when used with flush().**
        Cancelling this operation may leave the compressor state corrupted.
        After a cancellation the connection MUST be closed, otherwise later
        compress operations may produce corrupted data.

        Callers that need cancellation-safe compression (e.g., WebSocket) MUST
        wrap compress() + flush() + send in a shield and a lock so the whole
        sequence is atomic.
        """
        threshold = self._max_sync_chunk_size
        if threshold is None or len(data) <= threshold:
            return self.compress_sync(data)
        # Large payload: keep the event loop responsive by compressing in
        # the executor.
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(
            self._executor, self._compressor.compress, data
        )

    def flush(self, mode: int | None = None) -> bytes:
        """Flush the compressor synchronously.

        When ``mode`` is omitted the backend's ``Z_FINISH`` is used.

        **WARNING: This method is NOT cancellation-safe when called after compress().**
        flush() works on the same compressor state as compress(). If a
        compress() call was cancelled, flushing may yield corrupted data, so
        the connection MUST be closed after such a cancellation.

        Callers that need cancellation-safe compression (e.g., WebSocket) MUST
        wrap compress() + flush() + send in a shield and a lock so the whole
        sequence is atomic.
        """
        if mode is None:
            mode = self._zlib_backend.Z_FINISH
        return self._compressor.flush(mode)
266
267
class ZLibDecompressor(DecompressionBaseHandler):
    """Streaming gzip/deflate decompressor with multi-member support."""

    def __init__(
        self,
        encoding: str | None = None,
        suppress_deflate_header: bool = False,
        executor: Executor | None = None,
        max_sync_chunk_size: int | None = MAX_SYNC_CHUNK_SIZE,
    ):
        """Create a decompressor for ``encoding`` on the configured backend."""
        super().__init__(executor=executor, max_sync_chunk_size=max_sync_chunk_size)
        self._mode = encoding_to_mode(encoding, suppress_deflate_header)
        # Wrap the module that is configured right now, so this instance
        # keeps using it even if the global backend is replaced later.
        self._zlib_backend: Final = ZLibBackendWrapper(ZLibBackend._zlib_backend)
        self._decompressor = self._zlib_backend.decompressobj(wbits=self._mode)
        # True when the most recent decompress call produced no output.
        self._last_empty = False
        # Start of a following member that could not be processed because
        # the output budget of the previous call was exhausted.
        self._pending_unused_data: bytes | None = None

    def decompress_sync(
        self, data: Buffer, max_length: int = ZLIB_MAX_LENGTH_UNLIMITED
    ) -> bytes:
        """Decompress ``data`` and return the decompressed bytes.

        Args:
            data: next chunk of compressed input (may be empty to drain
                input retained from an earlier call).
            max_length: upper bound on the number of bytes returned by this
                call; ``ZLIB_MAX_LENGTH_UNLIMITED`` (0) means no bound.

        Returns:
            The decompressed bytes, spanning as many concatenated members as
            the input and ``max_length`` allow.
        """
        if self._pending_unused_data is not None:
            data = self._pending_unused_data + bytes(data)
            self._pending_unused_data = None
        result = self._decompressor.decompress(
            self._decompressor.unconsumed_tail + data, max_length
        )
        # Only way to know that isal has no further data is checking we get no output
        self._last_empty = result == b""

        # Handle concatenated gzip/deflate streams (multi-member).
        # After a member ends, unused_data holds the start of the next member.
        # Create a fresh decompressor for each subsequent member.
        while self._decompressor.eof and self._decompressor.unused_data:
            unused = self._decompressor.unused_data
            self._decompressor = self._zlib_backend.decompressobj(wbits=self._mode)
            remaining = max_length
            if max_length != ZLIB_MAX_LENGTH_UNLIMITED:
                # ``result`` is cumulative, so the budget left for the next
                # member is always measured against the caller's original
                # limit; subtracting from a running value would count
                # earlier members more than once.
                remaining = max_length - len(result)
                if remaining <= 0:
                    # Budget exhausted (0 would also mean "unlimited" to
                    # zlib): keep the next member for the following call.
                    self._pending_unused_data = unused
                    break
            chunk = self._decompressor.decompress(unused, remaining)
            self._last_empty = chunk == b""
            result += chunk

        # Member ended exactly at chunk boundary — no unused_data, but the
        # next feed_data() call would fail on the spent decompressor.
        # Only reset for gzip; deflate's feed_eof() relies on eof=True to
        # confirm the stream is complete.
        if self._decompressor.eof and self._mode > self._zlib_backend.MAX_WBITS:
            self._decompressor = self._zlib_backend.decompressobj(wbits=self._mode)

        return result

    def flush(self, length: int = 0) -> bytes:
        """Flush the current decompressor.

        A positive ``length`` is forwarded to the backend; otherwise the
        backend's own default is used.
        """
        return (
            self._decompressor.flush(length)
            if length > 0
            else self._decompressor.flush()
        )

    @property
    def data_available(self) -> bool:
        """Return True if calling decompress again may yield more output."""
        return (
            bool(self._decompressor.unconsumed_tail)
            or not self._last_empty
            or self._pending_unused_data is not None
        )

    @property
    def eof(self) -> bool:
        """Return the ``eof`` flag of the current decompressor object."""
        return self._decompressor.eof
337
338
class BrotliDecompressor(DecompressionBaseHandler):
    """Streaming Brotli decompressor.

    Works with both the 'brotlipy' and 'Brotli' packages, which share an
    import name: 'brotlipy' objects expose ``decompress``/``flush`` while
    'Brotli' objects expose ``process``.
    """

    def __init__(
        self,
        executor: Executor | None = None,
        max_sync_chunk_size: int | None = MAX_SYNC_CHUNK_SIZE,
    ) -> None:
        """Decompress data using the Brotli library."""
        if not HAS_BROTLI:
            raise RuntimeError(
                "The brotli decompression is not available. "
                "Please install `Brotli` module"
            )
        super().__init__(executor=executor, max_sync_chunk_size=max_sync_chunk_size)
        self._obj = brotli.Decompressor()
        # True when the most recent call produced no output.
        self._last_empty = False

    def decompress_sync(
        self, data: Buffer, max_length: int = ZLIB_MAX_LENGTH_UNLIMITED
    ) -> bytes:
        """Decompress the given data."""
        # Pick whichever entry point the installed package provides.
        method_name = "decompress" if hasattr(self._obj, "decompress") else "process"
        decode = getattr(self._obj, method_name)
        if max_length == ZLIB_MAX_LENGTH_UNLIMITED:
            raw = decode(data)
        else:
            raw = decode(data, max_length)
        result = cast(bytes, raw)
        # Only way to know that brotli has no further data is checking we get no output
        self._last_empty = result == b""
        return result

    def flush(self) -> bytes:
        """Flush the decompressor."""
        if not hasattr(self._obj, "flush"):
            return b""
        return cast(bytes, self._obj.flush())

    @property
    def data_available(self) -> bool:
        """Return True while the stream is unfinished and output keeps coming."""
        return not (self._obj.is_finished() or self._last_empty)
385
386
class ZSTDDecompressor(DecompressionBaseHandler):
    """Streaming Zstandard decompressor with multi-frame support."""

    def __init__(
        self,
        executor: Executor | None = None,
        max_sync_chunk_size: int | None = MAX_SYNC_CHUNK_SIZE,
    ) -> None:
        """Create the decompressor.

        Raises:
            RuntimeError: if no zstd implementation could be imported.
        """
        if not HAS_ZSTD:
            raise RuntimeError(
                "The zstd decompression is not available. "
                "Please install `backports.zstd` module"
            )
        self._obj = ZstdDecompressor()
        # Start of a following frame that could not be processed because
        # the output budget of the previous call was exhausted.
        self._pending_unused_data: bytes | None = None
        super().__init__(executor=executor, max_sync_chunk_size=max_sync_chunk_size)

    def decompress_sync(
        self, data: Buffer, max_length: int = ZLIB_MAX_LENGTH_UNLIMITED
    ) -> bytes:
        """Decompress ``data`` and return the decompressed bytes.

        Args:
            data: next chunk of compressed input.
            max_length: upper bound on the number of bytes returned by this
                call, using the zlib convention where
                ``ZLIB_MAX_LENGTH_UNLIMITED`` (0) means no bound.

        Returns:
            The decompressed bytes, spanning as many frames as the input
            and ``max_length`` allow.
        """
        # zstd uses -1 for unlimited, while zlib uses 0 for unlimited
        # Convert the zlib convention (0=unlimited) to zstd convention (-1=unlimited)
        zstd_max_length = (
            ZSTD_MAX_LENGTH_UNLIMITED
            if max_length == ZLIB_MAX_LENGTH_UNLIMITED
            else max_length
        )
        if self._pending_unused_data is not None:
            data = self._pending_unused_data + data
            self._pending_unused_data = None
        result = self._obj.decompress(data, zstd_max_length)

        # Handle multi-frame zstd streams.
        # https://datatracker.ietf.org/doc/html/rfc8878#section-3.1.1
        # ZstdDecompressor handles one frame only. When a frame ends,
        # eof becomes True and any trailing data goes to unused_data.
        # We create a fresh decompressor to continue with the next frame.
        while self._obj.eof and self._obj.unused_data:
            unused_data = self._obj.unused_data
            self._obj = ZstdDecompressor()
            remaining = zstd_max_length
            if zstd_max_length != ZSTD_MAX_LENGTH_UNLIMITED:
                # ``result`` is cumulative, so the budget left for the next
                # frame is always measured against the original limit;
                # subtracting from a running value would count earlier
                # frames more than once.
                remaining = zstd_max_length - len(result)
                if remaining <= 0:
                    # Budget exhausted: keep the next frame for later.
                    self._pending_unused_data = unused_data
                    break
            result += self._obj.decompress(unused_data, remaining)

        # Frame ended exactly at chunk boundary — no unused_data, but the
        # next feed_data() call would fail on the spent decompressor.
        # Prepare a fresh one for the next chunk.
        if self._obj.eof:
            self._obj = ZstdDecompressor()

        return result

    def flush(self) -> bytes:
        """Return ``b""``; this decompressor has nothing to flush."""
        return b""

    @property
    def data_available(self) -> bool:
        """Return True if calling decompress again may yield more output."""
        return (
            not self._obj.needs_input and not self._obj.eof
        ) or self._pending_unused_data is not None