1import asyncio
2import sys
3import zlib
4from concurrent.futures import Executor
5from typing import Any, Final, Optional, Protocol, TypedDict, cast
6
7if sys.version_info >= (3, 12):
8 from collections.abc import Buffer
9else:
10 from typing import Union
11
12 Buffer = Union[bytes, bytearray, "memoryview[int]", "memoryview[bytes]"]
13
14try:
15 try:
16 import brotlicffi as brotli
17 except ImportError:
18 import brotli
19
20 HAS_BROTLI = True
21except ImportError: # pragma: no cover
22 HAS_BROTLI = False
23
24if sys.version_info >= (3, 14):
25 import compression.zstd # noqa: I900
26
27 HAS_ZSTD = True
28else:
29 try:
30 import zstandard
31
32 HAS_ZSTD = True
33 except ImportError:
34 HAS_ZSTD = False
35
# Payloads larger than this many bytes are (de)compressed in an executor
# instead of inline, so the event loop is not blocked by large chunks.
MAX_SYNC_CHUNK_SIZE = 1024
37
38
class ZLibCompressObjProtocol(Protocol):
    """Structural type of the object returned by a backend's ``compressobj()``."""

    def compress(self, data: Buffer) -> bytes: ...
    def flush(self, mode: int = ..., /) -> bytes: ...
42
43
class ZLibDecompressObjProtocol(Protocol):
    """Structural type of the object returned by a backend's ``decompressobj()``."""

    def decompress(self, data: Buffer, max_length: int = ...) -> bytes: ...
    def flush(self, length: int = ..., /) -> bytes: ...

    @property
    def eof(self) -> bool: ...
50
51
class ZLibBackendProtocol(Protocol):
    """Structural type of a zlib-compatible module usable as a backend.

    The stdlib ``zlib`` module satisfies this; drop-in replacements
    (e.g. isal/zlib-ng wrappers) are expected to as well.
    """

    # Constants mirrored from the zlib module API.
    MAX_WBITS: int
    Z_FULL_FLUSH: int
    Z_SYNC_FLUSH: int
    Z_BEST_SPEED: int
    Z_FINISH: int

    def compressobj(
        self,
        level: int = ...,
        method: int = ...,
        wbits: int = ...,
        memLevel: int = ...,
        strategy: int = ...,
        zdict: Optional[Buffer] = ...,
    ) -> ZLibCompressObjProtocol: ...
    def decompressobj(
        self, wbits: int = ..., zdict: Buffer = ...
    ) -> ZLibDecompressObjProtocol: ...

    def compress(
        self, data: Buffer, /, level: int = ..., wbits: int = ...
    ) -> bytes: ...
    def decompress(
        self, data: Buffer, /, wbits: int = ..., bufsize: int = ...
    ) -> bytes: ...
78
79
class CompressObjArgs(TypedDict, total=False):
    """Optional keyword arguments forwarded to the backend's ``compressobj()``."""

    wbits: int
    strategy: int
    level: int
84
85
class ZLibBackendWrapper:
    """Delegating facade over the active zlib-compatible backend module.

    Constants, the ``compressobj``/``decompressobj`` factories and the
    one-shot ``compress``/``decompress`` functions are forwarded to the
    wrapped module; any other attribute falls through via ``__getattr__``.
    The wrapped module can be swapped at runtime (see ``set_zlib_backend``).
    """

    def __init__(self, _zlib_backend: ZLibBackendProtocol):
        self._zlib_backend: ZLibBackendProtocol = _zlib_backend

    @property
    def name(self) -> str:
        # Backends are normally modules, so __name__ is usually present.
        return getattr(self._zlib_backend, "__name__", "undefined")

    @property
    def MAX_WBITS(self) -> int:
        return self._zlib_backend.MAX_WBITS

    @property
    def Z_FULL_FLUSH(self) -> int:
        return self._zlib_backend.Z_FULL_FLUSH

    @property
    def Z_SYNC_FLUSH(self) -> int:
        return self._zlib_backend.Z_SYNC_FLUSH

    @property
    def Z_BEST_SPEED(self) -> int:
        return self._zlib_backend.Z_BEST_SPEED

    @property
    def Z_FINISH(self) -> int:
        return self._zlib_backend.Z_FINISH

    def compressobj(self, *args: Any, **kwargs: Any) -> ZLibCompressObjProtocol:
        return self._zlib_backend.compressobj(*args, **kwargs)

    def decompressobj(self, *args: Any, **kwargs: Any) -> ZLibDecompressObjProtocol:
        return self._zlib_backend.decompressobj(*args, **kwargs)

    def compress(self, data: Buffer, *args: Any, **kwargs: Any) -> bytes:
        return self._zlib_backend.compress(data, *args, **kwargs)

    def decompress(self, data: Buffer, *args: Any, **kwargs: Any) -> bytes:
        return self._zlib_backend.decompress(data, *args, **kwargs)

    # Everything not explicitly listed in the Protocol we just pass through
    def __getattr__(self, attrname: str) -> Any:
        return getattr(self._zlib_backend, attrname)
129
130
# Process-wide default backend (stdlib zlib), mutated in place by
# set_zlib_backend() so existing references observe the swap.
ZLibBackend: ZLibBackendWrapper = ZLibBackendWrapper(zlib)
132
133
def set_zlib_backend(new_zlib_backend: ZLibBackendProtocol) -> None:
    """Replace the process-wide zlib backend module.

    Compressor/decompressor instances snapshot the backend at construction
    time, so only objects created after this call use the new module.
    """
    ZLibBackend._zlib_backend = new_zlib_backend
136
137
def encoding_to_mode(
    encoding: Optional[str] = None,
    suppress_deflate_header: bool = False,
) -> int:
    """Map a content encoding to the wbits mode for the zlib backend.

    ``"gzip"`` selects gzip framing (``16 + MAX_WBITS``).  Otherwise a raw
    deflate stream (negative wbits) is selected when
    *suppress_deflate_header* is true, else a standard zlib stream.
    """
    max_wbits = ZLibBackend.MAX_WBITS
    if encoding == "gzip":
        return 16 + max_wbits
    if suppress_deflate_header:
        return -max_wbits
    return max_wbits
146
147
class ZlibBaseHandler:
    """Shared state for the streaming compressor/decompressor handlers."""

    def __init__(
        self,
        mode: int,
        executor: Optional[Executor] = None,
        max_sync_chunk_size: Optional[int] = MAX_SYNC_CHUNK_SIZE,
    ):
        self._mode = mode  # wbits value handed to the zlib backend
        self._executor = executor  # executor for large payloads; None -> loop default
        self._max_sync_chunk_size = max_sync_chunk_size  # offload threshold; None disables offload
158
159
class ZLibCompressor(ZlibBaseHandler):
    """Streaming zlib/gzip compressor with optional executor offload.

    The current zlib backend is snapshotted at construction time, so a
    later ``set_zlib_backend()`` call does not affect an in-progress
    stream.
    """

    def __init__(
        self,
        encoding: Optional[str] = None,
        suppress_deflate_header: bool = False,
        level: Optional[int] = None,
        wbits: Optional[int] = None,
        strategy: Optional[int] = None,
        executor: Optional[Executor] = None,
        max_sync_chunk_size: Optional[int] = MAX_SYNC_CHUNK_SIZE,
    ):
        # An explicitly provided wbits wins over the encoding-derived mode.
        super().__init__(
            mode=(
                encoding_to_mode(encoding, suppress_deflate_header)
                if wbits is None
                else wbits
            ),
            executor=executor,
            max_sync_chunk_size=max_sync_chunk_size,
        )
        self._zlib_backend: Final = ZLibBackendWrapper(ZLibBackend._zlib_backend)

        # Only forward arguments that were explicitly given, so the
        # backend's own defaults apply otherwise.
        kwargs: CompressObjArgs = {}
        kwargs["wbits"] = self._mode
        if strategy is not None:
            kwargs["strategy"] = strategy
        if level is not None:
            kwargs["level"] = level
        self._compressor = self._zlib_backend.compressobj(**kwargs)
        self._compress_lock = asyncio.Lock()

    def compress_sync(self, data: bytes) -> bytes:
        """Compress *data* synchronously on the calling thread."""
        return self._compressor.compress(data)

    async def compress(self, data: bytes) -> bytes:
        """Compress the data and return the compressed bytes.

        Note that flush() must be called after the last call to compress()

        If the data size is larger than the max_sync_chunk_size, the
        compression will be done in the executor. Otherwise, the
        compression will be done in the event loop.
        """
        async with self._compress_lock:
            # To ensure the stream is consistent in the event
            # there are multiple writers, we need to lock
            # the compressor so that only one writer can
            # compress at a time.
            if (
                self._max_sync_chunk_size is not None
                and len(data) > self._max_sync_chunk_size
            ):
                return await asyncio.get_running_loop().run_in_executor(
                    self._executor, self._compressor.compress, data
                )
            return self.compress_sync(data)

    def flush(self, mode: Optional[int] = None) -> bytes:
        """Flush the stream with *mode*, defaulting to Z_FINISH."""
        return self._compressor.flush(
            mode if mode is not None else self._zlib_backend.Z_FINISH
        )
221
222
class ZLibDecompressor(ZlibBaseHandler):
    """Streaming zlib/gzip decompressor with optional executor offload.

    The current zlib backend is snapshotted at construction time, so a
    later ``set_zlib_backend()`` call does not affect an in-progress
    stream.
    """

    def __init__(
        self,
        encoding: Optional[str] = None,
        suppress_deflate_header: bool = False,
        executor: Optional[Executor] = None,
        max_sync_chunk_size: Optional[int] = MAX_SYNC_CHUNK_SIZE,
    ):
        super().__init__(
            mode=encoding_to_mode(encoding, suppress_deflate_header),
            executor=executor,
            max_sync_chunk_size=max_sync_chunk_size,
        )
        self._zlib_backend: Final = ZLibBackendWrapper(ZLibBackend._zlib_backend)
        self._decompressor = self._zlib_backend.decompressobj(wbits=self._mode)

    def decompress_sync(self, data: bytes, max_length: int = 0) -> bytes:
        """Decompress *data* synchronously on the calling thread."""
        return self._decompressor.decompress(data, max_length)

    async def decompress(self, data: bytes, max_length: int = 0) -> bytes:
        """Decompress the data and return the decompressed bytes.

        Payloads larger than max_sync_chunk_size are handed off to the
        executor so the event loop is not blocked; smaller payloads are
        decompressed inline.
        """
        threshold = self._max_sync_chunk_size
        if threshold is None or len(data) <= threshold:
            return self.decompress_sync(data, max_length)
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(
            self._executor, self._decompressor.decompress, data, max_length
        )

    def flush(self, length: int = 0) -> bytes:
        """Return any remaining buffered output from the decompressor."""
        if length > 0:
            return self._decompressor.flush(length)
        return self._decompressor.flush()

    @property
    def eof(self) -> bool:
        """True once the end of the compressed stream has been reached."""
        return self._decompressor.eof
268
269
class BrotliDecompressor:
    """Incremental brotli decompressor.

    Supports both the 'brotlipy' and 'Brotli' packages, which share the
    ``brotli`` import name but expose different APIs: brotlipy provides
    ``decompress``/``flush``, while Brotli provides ``process`` and has
    nothing to flush.
    """

    def __init__(self) -> None:
        if not HAS_BROTLI:
            raise RuntimeError(
                "The brotli decompression is not available. "
                "Please install `Brotli` module"
            )
        self._obj = brotli.Decompressor()

    def decompress_sync(self, data: bytes) -> bytes:
        """Feed *data* into the decompressor and return available output."""
        decompress = getattr(self._obj, "decompress", None)
        if decompress is not None:  # brotlipy
            return cast(bytes, decompress(data))
        return cast(bytes, self._obj.process(data))  # Brotli

    def flush(self) -> bytes:
        """Flush pending output; Brotli has no flush, so return b'' there."""
        flush = getattr(self._obj, "flush", None)
        if flush is not None:  # brotlipy
            return cast(bytes, flush())
        return b""  # Brotli
291
292
class ZSTDDecompressor:
    """Incremental zstd decompressor over whichever backend is available.

    On Python 3.14+ the stdlib ``compression.zstd`` module is used; on
    earlier versions the third-party ``zstandard`` package is required.

    Raises:
        RuntimeError: if no zstd backend is installed (``HAS_ZSTD`` false).
    """

    def __init__(self) -> None:
        if not HAS_ZSTD:
            raise RuntimeError(
                "The zstd decompression is not available. "
                "Please install `zstandard` module"
            )
        if sys.version_info >= (3, 14):
            # compression.zstd.ZstdDecompressor is already a streaming,
            # zlib-style object that accepts data chunk by chunk.
            self._obj = compression.zstd.ZstdDecompressor()
        else:
            # Bug fix: zstandard.ZstdDecompressor.decompress() is a one-shot
            # API that only handles complete frames with a known content
            # size, which fails on chunked/partial input.  Use its
            # zlib-compatible streaming decompressobj() instead so partial
            # frames can be fed incrementally.
            self._obj = zstandard.ZstdDecompressor().decompressobj()

    def decompress_sync(self, data: bytes) -> bytes:
        """Feed *data* into the decompressor and return available output."""
        return self._obj.decompress(data)

    def flush(self) -> bytes:
        """zstd keeps no flushable tail here; always returns b''."""
        return b""