1import asyncio
2import sys
3import zlib
4from abc import ABC, abstractmethod
5from concurrent.futures import Executor
6from typing import Any, Final, Optional, Protocol, TypedDict, cast
7
8if sys.version_info >= (3, 12):
9 from collections.abc import Buffer
10else:
11 from typing import Union
12
13 Buffer = Union[bytes, bytearray, "memoryview[int]", "memoryview[bytes]"]
14
15try:
16 try:
17 import brotlicffi as brotli
18 except ImportError:
19 import brotli
20
21 HAS_BROTLI = True
22except ImportError: # pragma: no cover
23 HAS_BROTLI = False
24
25try:
26 if sys.version_info >= (3, 14):
27 from compression.zstd import ZstdDecompressor # noqa: I900
28 else: # TODO(PY314): Remove mentions of backports.zstd across codebase
29 from backports.zstd import ZstdDecompressor
30
31 HAS_ZSTD = True
32except ImportError:
33 HAS_ZSTD = False
34
35
# Chunks at or below this size are (de)compressed inline on the event loop;
# larger payloads are offloaded to an executor (see DecompressionBaseHandler
# and ZLibCompressor below).
MAX_SYNC_CHUNK_SIZE = 4096
# Default upper bound on decompressed output size; presumably consumed by
# callers outside this chunk — not referenced in the visible code.
DEFAULT_MAX_DECOMPRESS_SIZE = 2**25  # 32MiB

# Unlimited decompression constants - different libraries use different conventions
ZLIB_MAX_LENGTH_UNLIMITED = 0  # zlib uses 0 to mean unlimited
ZSTD_MAX_LENGTH_UNLIMITED = -1  # zstd uses -1 to mean unlimited
42
43
class ZLibCompressObjProtocol(Protocol):
    """Structural type of a streaming compressor object (e.g. ``zlib.compressobj()``)."""

    def compress(self, data: Buffer) -> bytes: ...
    def flush(self, mode: int = ..., /) -> bytes: ...
47
48
class ZLibDecompressObjProtocol(Protocol):
    """Structural type of a streaming decompressor object (e.g. ``zlib.decompressobj()``)."""

    def decompress(self, data: Buffer, max_length: int = ...) -> bytes: ...
    def flush(self, length: int = ..., /) -> bytes: ...

    @property
    def eof(self) -> bool: ...
55
56
class ZLibBackendProtocol(Protocol):
    """Structural type of a zlib-compatible module (stdlib ``zlib`` or a drop-in).

    Covers the constants, streaming-object factories, and one-shot helpers
    that this module forwards through ZLibBackendWrapper.
    """

    # Flush/level constants mirrored from the backend module.
    MAX_WBITS: int
    Z_FULL_FLUSH: int
    Z_SYNC_FLUSH: int
    Z_BEST_SPEED: int
    Z_FINISH: int

    def compressobj(
        self,
        level: int = ...,
        method: int = ...,
        wbits: int = ...,
        memLevel: int = ...,
        strategy: int = ...,
        zdict: Optional[Buffer] = ...,
    ) -> ZLibCompressObjProtocol: ...
    def decompressobj(
        self, wbits: int = ..., zdict: Buffer = ...
    ) -> ZLibDecompressObjProtocol: ...

    # One-shot (non-streaming) helpers.
    def compress(
        self, data: Buffer, /, level: int = ..., wbits: int = ...
    ) -> bytes: ...
    def decompress(
        self, data: Buffer, /, wbits: int = ..., bufsize: int = ...
    ) -> bytes: ...
83
84
class CompressObjArgs(TypedDict, total=False):
    """Optional keyword arguments forwarded to ``compressobj()``.

    ``total=False``: only keys that were explicitly set are passed through,
    letting the backend apply its own defaults for the rest.
    """

    wbits: int
    strategy: int
    level: int
89
90
class ZLibBackendWrapper:
    """Delegating facade over a zlib-compatible backend module.

    The ZLibBackendProtocol surface (constants, streaming-object factories,
    one-shot helpers) is forwarded explicitly; any other attribute falls
    through to the wrapped module via ``__getattr__``.
    """

    def __init__(self, _zlib_backend: "ZLibBackendProtocol"):
        self._zlib_backend: "ZLibBackendProtocol" = _zlib_backend

    @property
    def name(self) -> str:
        """The wrapped module's ``__name__``, or ``"undefined"`` if absent."""
        try:
            return self._zlib_backend.__name__  # type: ignore[attr-defined]
        except AttributeError:
            return "undefined"

    @property
    def MAX_WBITS(self) -> int:
        """Backend's ``MAX_WBITS`` constant."""
        return self._zlib_backend.MAX_WBITS

    @property
    def Z_FULL_FLUSH(self) -> int:
        """Backend's ``Z_FULL_FLUSH`` flush mode."""
        return self._zlib_backend.Z_FULL_FLUSH

    @property
    def Z_SYNC_FLUSH(self) -> int:
        """Backend's ``Z_SYNC_FLUSH`` flush mode."""
        return self._zlib_backend.Z_SYNC_FLUSH

    @property
    def Z_BEST_SPEED(self) -> int:
        """Backend's ``Z_BEST_SPEED`` compression level."""
        return self._zlib_backend.Z_BEST_SPEED

    @property
    def Z_FINISH(self) -> int:
        """Backend's ``Z_FINISH`` flush mode."""
        return self._zlib_backend.Z_FINISH

    def compressobj(self, *args: Any, **kwargs: Any) -> "ZLibCompressObjProtocol":
        """Create a streaming compressor on the wrapped backend."""
        return self._zlib_backend.compressobj(*args, **kwargs)

    def decompressobj(self, *args: Any, **kwargs: Any) -> "ZLibDecompressObjProtocol":
        """Create a streaming decompressor on the wrapped backend."""
        return self._zlib_backend.decompressobj(*args, **kwargs)

    def compress(self, data: "Buffer", *args: Any, **kwargs: Any) -> bytes:
        """One-shot compress of *data* via the wrapped backend."""
        return self._zlib_backend.compress(data, *args, **kwargs)

    def decompress(self, data: "Buffer", *args: Any, **kwargs: Any) -> bytes:
        """One-shot decompress of *data* via the wrapped backend."""
        return self._zlib_backend.decompress(data, *args, **kwargs)

    def __getattr__(self, attrname: str) -> Any:
        # Anything not explicitly forwarded above is passed straight through.
        return getattr(self._zlib_backend, attrname)
134
135
# Module-level default backend: the stdlib ``zlib`` module, replaceable at
# runtime via set_zlib_backend().
ZLibBackend: ZLibBackendWrapper = ZLibBackendWrapper(zlib)
137
138
def set_zlib_backend(new_zlib_backend: ZLibBackendProtocol) -> None:
    """Swap the zlib implementation used by the module-level ``ZLibBackend``.

    Note: ZLibCompressor/ZLibDecompressor snapshot the backend at
    construction time, so instances created before this call keep using
    the previous backend.
    """
    ZLibBackend._zlib_backend = new_zlib_backend
141
142
def encoding_to_mode(
    encoding: Optional[str] = None,
    suppress_deflate_header: bool = False,
) -> int:
    """Translate a content-encoding into a zlib ``wbits`` value.

    "gzip" selects gzip framing (16 + MAX_WBITS); otherwise a negative
    value requests raw deflate (no zlib header) when the header is
    suppressed, and plain zlib framing when it is not.
    """
    max_wbits = ZLibBackend.MAX_WBITS
    if encoding == "gzip":
        return 16 + max_wbits
    if suppress_deflate_header:
        return -max_wbits
    return max_wbits
151
152
class DecompressionBaseHandler(ABC):
    """Shared sync/async plumbing for the decompressor classes.

    Subclasses implement :meth:`decompress_sync`; :meth:`decompress` decides
    whether to run it inline on the event loop or offload it to an executor
    based on the payload size.
    """

    def __init__(
        self,
        executor: Optional[Executor] = None,
        max_sync_chunk_size: Optional[int] = MAX_SYNC_CHUNK_SIZE,
    ):
        """Base class for decompression handlers.

        executor: executor used for oversized chunks (None = loop default).
        max_sync_chunk_size: payloads strictly larger than this are offloaded
            to the executor; None disables offloading entirely.
        """
        self._executor = executor
        self._max_sync_chunk_size = max_sync_chunk_size

    @abstractmethod
    def decompress_sync(
        self, data: bytes, max_length: int = ZLIB_MAX_LENGTH_UNLIMITED
    ) -> bytes:
        """Decompress the given data."""

    async def decompress(
        self, data: bytes, max_length: int = ZLIB_MAX_LENGTH_UNLIMITED
    ) -> bytes:
        """Decompress the given data."""
        if (
            self._max_sync_chunk_size is not None
            and len(data) > self._max_sync_chunk_size
        ):
            # Use get_running_loop() rather than the (deprecated-in-coroutine)
            # get_event_loop(), matching ZLibCompressor.compress below.
            return await asyncio.get_running_loop().run_in_executor(
                self._executor, self.decompress_sync, data, max_length
            )
        return self.decompress_sync(data, max_length)
181
182
class ZLibCompressor:
    """Streaming deflate/gzip compressor with optional executor offload."""

    def __init__(
        self,
        encoding: Optional[str] = None,
        suppress_deflate_header: bool = False,
        level: Optional[int] = None,
        wbits: Optional[int] = None,
        strategy: Optional[int] = None,
        executor: Optional[Executor] = None,
        max_sync_chunk_size: Optional[int] = MAX_SYNC_CHUNK_SIZE,
    ):
        self._executor = executor
        self._max_sync_chunk_size = max_sync_chunk_size
        # An explicit wbits wins; otherwise derive the mode from the encoding.
        if wbits is None:
            self._mode = encoding_to_mode(encoding, suppress_deflate_header)
        else:
            self._mode = wbits
        # Snapshot the backend so a later set_zlib_backend() call cannot
        # switch implementations mid-stream.
        self._zlib_backend: Final = ZLibBackendWrapper(ZLibBackend._zlib_backend)

        opts: CompressObjArgs = {"wbits": self._mode}
        if strategy is not None:
            opts["strategy"] = strategy
        if level is not None:
            opts["level"] = level
        self._compressor = self._zlib_backend.compressobj(**opts)

    def compress_sync(self, data: bytes) -> bytes:
        """Feed *data* to the compressor synchronously."""
        return self._compressor.compress(data)

    async def compress(self, data: bytes) -> bytes:
        """Compress the data and return the compressed bytes.

        Note that flush() must be called after the last call to compress()

        Payloads larger than max_sync_chunk_size are compressed in the
        executor; smaller ones are compressed inline on the event loop.

        **WARNING: This method is NOT cancellation-safe when used with flush().**
        If this operation is cancelled, the compressor state may be corrupted.
        The connection MUST be closed after cancellation to avoid data corruption
        in subsequent compress operations.

        For cancellation-safe compression (e.g., WebSocket), the caller MUST wrap
        compress() + flush() + send operations in a shield and lock to ensure atomicity.
        """
        limit = self._max_sync_chunk_size
        if limit is None or len(data) <= limit:
            return self.compress_sync(data)
        # Large payload: offload to the executor so the event loop stays free.
        return await asyncio.get_running_loop().run_in_executor(
            self._executor, self._compressor.compress, data
        )

    def flush(self, mode: Optional[int] = None) -> bytes:
        """Flush the compressor synchronously (defaults to Z_FINISH).

        **WARNING: This method is NOT cancellation-safe when called after compress().**
        The flush() operation accesses shared compressor state. If compress() was
        cancelled, calling flush() may result in corrupted data. The connection MUST
        be closed after compress() cancellation.

        For cancellation-safe compression (e.g., WebSocket), the caller MUST wrap
        compress() + flush() + send operations in a shield and lock to ensure atomicity.
        """
        if mode is None:
            mode = self._zlib_backend.Z_FINISH
        return self._compressor.flush(mode)
256
257
class ZLibDecompressor(DecompressionBaseHandler):
    """Streaming deflate/gzip decompressor."""

    def __init__(
        self,
        encoding: Optional[str] = None,
        suppress_deflate_header: bool = False,
        executor: Optional[Executor] = None,
        max_sync_chunk_size: Optional[int] = MAX_SYNC_CHUNK_SIZE,
    ):
        super().__init__(executor=executor, max_sync_chunk_size=max_sync_chunk_size)
        self._mode = encoding_to_mode(encoding, suppress_deflate_header)
        # Pin whichever backend is active at construction time.
        self._zlib_backend: Final = ZLibBackendWrapper(ZLibBackend._zlib_backend)
        self._decompressor = self._zlib_backend.decompressobj(wbits=self._mode)

    def decompress_sync(
        self, data: Buffer, max_length: int = ZLIB_MAX_LENGTH_UNLIMITED
    ) -> bytes:
        """Decompress *data*; max_length of 0 means unlimited (zlib convention)."""
        return self._decompressor.decompress(data, max_length)

    def flush(self, length: int = 0) -> bytes:
        """Flush any remaining buffered output.

        A positive *length* is forwarded to the backend's flush(); otherwise
        the backend's default is used.
        """
        if length > 0:
            return self._decompressor.flush(length)
        return self._decompressor.flush()

    @property
    def eof(self) -> bool:
        """True once the end of the compressed stream has been reached."""
        return self._decompressor.eof
286
287
class BrotliDecompressor(DecompressionBaseHandler):
    """Decompressor for Brotli-encoded payloads.

    Supports both the 'brotlipy' and 'Brotli' packages, which share an
    import name but differ in API: the former exposes decompress()/flush()
    on its decompressor object, the latter exposes process().
    """

    def __init__(
        self,
        executor: Optional[Executor] = None,
        max_sync_chunk_size: Optional[int] = MAX_SYNC_CHUNK_SIZE,
    ) -> None:
        """Decompress data using the Brotli library."""
        if not HAS_BROTLI:
            raise RuntimeError(
                "The brotli decompression is not available. "
                "Please install `Brotli` module"
            )
        self._obj = brotli.Decompressor()
        super().__init__(executor=executor, max_sync_chunk_size=max_sync_chunk_size)

    def decompress_sync(
        self, data: Buffer, max_length: int = ZLIB_MAX_LENGTH_UNLIMITED
    ) -> bytes:
        """Decompress the given data."""
        # 'brotlipy' path (decompress()); fall back to 'Brotli' (process()).
        decompress = getattr(self._obj, "decompress", None)
        if decompress is not None:
            return cast(bytes, decompress(data, max_length))
        return cast(bytes, self._obj.process(data, max_length))

    def flush(self) -> bytes:
        """Flush the decompressor ('Brotli' has no flush; returns b"")."""
        flush = getattr(self._obj, "flush", None)
        if flush is None:
            return b""
        return cast(bytes, flush())
319
320
class ZSTDDecompressor(DecompressionBaseHandler):
    """Decompressor for Zstandard-encoded payloads."""

    def __init__(
        self,
        executor: Optional[Executor] = None,
        max_sync_chunk_size: Optional[int] = MAX_SYNC_CHUNK_SIZE,
    ) -> None:
        if not HAS_ZSTD:
            raise RuntimeError(
                "The zstd decompression is not available. "
                "Please install `backports.zstd` module"
            )
        self._obj = ZstdDecompressor()
        super().__init__(executor=executor, max_sync_chunk_size=max_sync_chunk_size)

    def decompress_sync(
        self, data: bytes, max_length: int = ZLIB_MAX_LENGTH_UNLIMITED
    ) -> bytes:
        """Decompress *data*, translating zlib's 0=unlimited to zstd's -1."""
        if max_length == ZLIB_MAX_LENGTH_UNLIMITED:
            return self._obj.decompress(data, ZSTD_MAX_LENGTH_UNLIMITED)
        return self._obj.decompress(data, max_length)

    def flush(self) -> bytes:
        """Nothing to flush; present for interface symmetry."""
        return b""