1import asyncio
2import sys
3import zlib
4from concurrent.futures import Executor
5from typing import Any, Final, Optional, Protocol, TypedDict, cast
6
7if sys.version_info >= (3, 12):
8 from collections.abc import Buffer
9else:
10 from typing import Union
11
12 Buffer = Union[bytes, bytearray, "memoryview[int]", "memoryview[bytes]"]
13
14try:
15 try:
16 import brotlicffi as brotli
17 except ImportError:
18 import brotli
19
20 HAS_BROTLI = True
21except ImportError: # pragma: no cover
22 HAS_BROTLI = False
23
24try:
25 if sys.version_info >= (3, 14):
26 from compression.zstd import ZstdDecompressor # noqa: I900
27 else: # TODO(PY314): Remove mentions of backports.zstd across codebase
28 from backports.zstd import ZstdDecompressor
29
30 HAS_ZSTD = True
31except ImportError:
32 HAS_ZSTD = False
33
34
# Payloads up to this many bytes are (de)compressed synchronously in the
# event loop; larger payloads are offloaded to an executor (see
# ZLibCompressor.compress / ZLibDecompressor.decompress).
MAX_SYNC_CHUNK_SIZE = 1024
36
37
class ZLibCompressObjProtocol(Protocol):
    """Structural type of a zlib-style compression object.

    Matches the objects returned by ``compressobj()`` of :mod:`zlib` or of
    a compatible third-party backend module.
    """

    def compress(self, data: Buffer) -> bytes: ...
    def flush(self, mode: int = ..., /) -> bytes: ...
41
42
class ZLibDecompressObjProtocol(Protocol):
    """Structural type of a zlib-style decompression object.

    Matches the objects returned by ``decompressobj()`` of :mod:`zlib` or of
    a compatible third-party backend module.
    """

    def decompress(self, data: Buffer, max_length: int = ...) -> bytes: ...
    def flush(self, length: int = ..., /) -> bytes: ...

    @property
    def eof(self) -> bool: ...
49
50
class ZLibBackendProtocol(Protocol):
    """Structural type of a zlib-compatible backend module.

    Only the constants and callables actually used by this module are
    declared; any module exposing this subset (e.g. :mod:`zlib` itself or
    a zlib-ng wrapper) can be installed via :func:`set_zlib_backend`.
    """

    # Constants mirrored from the zlib module API.
    MAX_WBITS: int
    Z_FULL_FLUSH: int
    Z_SYNC_FLUSH: int
    Z_BEST_SPEED: int
    Z_FINISH: int

    def compressobj(
        self,
        level: int = ...,
        method: int = ...,
        wbits: int = ...,
        memLevel: int = ...,
        strategy: int = ...,
        zdict: Optional[Buffer] = ...,
    ) -> ZLibCompressObjProtocol: ...
    def decompressobj(
        self, wbits: int = ..., zdict: Buffer = ...
    ) -> ZLibDecompressObjProtocol: ...

    def compress(
        self, data: Buffer, /, level: int = ..., wbits: int = ...
    ) -> bytes: ...
    def decompress(
        self, data: Buffer, /, wbits: int = ..., bufsize: int = ...
    ) -> bytes: ...
77
78
class CompressObjArgs(TypedDict, total=False):
    """Optional keyword arguments forwarded to the backend's ``compressobj()``.

    ``total=False``: a key is present only when the caller explicitly set
    the corresponding option (see ZLibCompressor.__init__).
    """

    wbits: int
    strategy: int
    level: int
83
84
class ZLibBackendWrapper:
    """Thin proxy around a zlib-compatible backend module.

    Constants are exposed as properties so that reads always reflect the
    currently assigned ``_zlib_backend``, which :func:`set_zlib_backend`
    may swap at runtime.
    """

    def __init__(self, _zlib_backend: ZLibBackendProtocol):
        # The wrapped backend module; reassigned in place by
        # set_zlib_backend() on the module-level singleton.
        self._zlib_backend: ZLibBackendProtocol = _zlib_backend

    @property
    def name(self) -> str:
        # Module name of the backend, or "undefined" if it has no __name__.
        return getattr(self._zlib_backend, "__name__", "undefined")

    @property
    def MAX_WBITS(self) -> int:
        return self._zlib_backend.MAX_WBITS

    @property
    def Z_FULL_FLUSH(self) -> int:
        return self._zlib_backend.Z_FULL_FLUSH

    @property
    def Z_SYNC_FLUSH(self) -> int:
        return self._zlib_backend.Z_SYNC_FLUSH

    @property
    def Z_BEST_SPEED(self) -> int:
        return self._zlib_backend.Z_BEST_SPEED

    @property
    def Z_FINISH(self) -> int:
        return self._zlib_backend.Z_FINISH

    def compressobj(self, *args: Any, **kwargs: Any) -> ZLibCompressObjProtocol:
        return self._zlib_backend.compressobj(*args, **kwargs)

    def decompressobj(self, *args: Any, **kwargs: Any) -> ZLibDecompressObjProtocol:
        return self._zlib_backend.decompressobj(*args, **kwargs)

    def compress(self, data: Buffer, *args: Any, **kwargs: Any) -> bytes:
        return self._zlib_backend.compress(data, *args, **kwargs)

    def decompress(self, data: Buffer, *args: Any, **kwargs: Any) -> bytes:
        return self._zlib_backend.decompress(data, *args, **kwargs)

    # Everything not explicitly listed in the Protocol we just pass through
    def __getattr__(self, attrname: str) -> Any:
        return getattr(self._zlib_backend, attrname)
128
129
# Process-wide default backend (stdlib zlib unless replaced via
# set_zlib_backend()). (De)compressor instances snapshot its current
# backend at construction time.
ZLibBackend: ZLibBackendWrapper = ZLibBackendWrapper(zlib)
131
132
def set_zlib_backend(new_zlib_backend: ZLibBackendProtocol) -> None:
    """Replace the process-wide zlib backend used by future (de)compressors.

    Existing ZLibCompressor/ZLibDecompressor instances are unaffected:
    they captured the backend that was current when they were created.
    """
    ZLibBackend._zlib_backend = new_zlib_backend
135
136
def encoding_to_mode(
    encoding: Optional[str] = None,
    suppress_deflate_header: bool = False,
) -> int:
    """Translate a content encoding into a zlib ``wbits`` mode.

    ``"gzip"`` selects the gzip container (``MAX_WBITS + 16``); anything
    else selects deflate, negated (raw stream, no zlib header/trailer)
    when ``suppress_deflate_header`` is true.
    """
    wbits = ZLibBackend.MAX_WBITS
    if encoding == "gzip":
        return wbits + 16
    if suppress_deflate_header:
        return -wbits
    return wbits
145
146
class ZlibBaseHandler:
    """Shared state for the compressor/decompressor wrappers below.

    ``mode`` is the zlib ``wbits`` value, ``executor`` (optional) runs
    large payloads off the event loop, and ``max_sync_chunk_size`` is the
    byte threshold above which that offload happens (None disables it).
    """

    def __init__(
        self,
        mode: int,
        executor: Optional[Executor] = None,
        max_sync_chunk_size: Optional[int] = MAX_SYNC_CHUNK_SIZE,
    ):
        self._mode = mode  # wbits passed to the backend (de)compressobj
        self._executor = executor  # None -> event loop's default executor
        self._max_sync_chunk_size = max_sync_chunk_size  # offload threshold
157
158
class ZLibCompressor(ZlibBaseHandler):
    """Incremental deflate/gzip compressor with optional executor offload."""

    def __init__(
        self,
        encoding: Optional[str] = None,
        suppress_deflate_header: bool = False,
        level: Optional[int] = None,
        wbits: Optional[int] = None,
        strategy: Optional[int] = None,
        executor: Optional[Executor] = None,
        max_sync_chunk_size: Optional[int] = MAX_SYNC_CHUNK_SIZE,
    ):
        # An explicit wbits value takes precedence over the encoding.
        if wbits is not None:
            mode = wbits
        else:
            mode = encoding_to_mode(encoding, suppress_deflate_header)
        super().__init__(
            mode=mode,
            executor=executor,
            max_sync_chunk_size=max_sync_chunk_size,
        )
        # Snapshot the current global backend so a later set_zlib_backend()
        # call cannot swap backends on this instance mid-stream.
        self._zlib_backend: Final = ZLibBackendWrapper(ZLibBackend._zlib_backend)

        # Only pass options the caller explicitly set, so the backend's
        # own defaults apply otherwise.
        kwargs: CompressObjArgs = {"wbits": self._mode}
        if strategy is not None:
            kwargs["strategy"] = strategy
        if level is not None:
            kwargs["level"] = level
        self._compressor = self._zlib_backend.compressobj(**kwargs)

    def compress_sync(self, data: bytes) -> bytes:
        """Compress *data* synchronously on the calling thread."""
        return self._compressor.compress(data)

    async def compress(self, data: bytes) -> bytes:
        """Compress *data* and return the compressed bytes.

        Note that flush() must be called after the last call to compress().

        Payloads larger than ``max_sync_chunk_size`` are compressed in the
        executor; smaller ones directly in the event loop.

        **WARNING: This method is NOT cancellation-safe when used with
        flush().** If this operation is cancelled, the compressor state may
        be corrupted. The connection MUST be closed after cancellation to
        avoid data corruption in subsequent compress operations.

        For cancellation-safe compression (e.g., WebSocket), the caller
        MUST wrap compress() + flush() + send operations in a shield and
        lock to ensure atomicity.
        """
        threshold = self._max_sync_chunk_size
        if threshold is not None and len(data) > threshold:
            # Large payload: offload to the executor so the event loop
            # stays responsive.
            loop = asyncio.get_running_loop()
            return await loop.run_in_executor(
                self._executor, self._compressor.compress, data
            )
        return self.compress_sync(data)

    def flush(self, mode: Optional[int] = None) -> bytes:
        """Flush the compressor synchronously (defaults to Z_FINISH).

        **WARNING: This method is NOT cancellation-safe when called after
        compress().** flush() accesses the shared compressor state; if
        compress() was cancelled, calling flush() may produce corrupted
        data. The connection MUST be closed after compress() cancellation.

        For cancellation-safe compression (e.g., WebSocket), the caller
        MUST wrap compress() + flush() + send operations in a shield and
        lock to ensure atomicity.
        """
        if mode is None:
            mode = self._zlib_backend.Z_FINISH
        return self._compressor.flush(mode)
234
235
class ZLibDecompressor(ZlibBaseHandler):
    """Incremental deflate/gzip decompressor with optional executor offload."""

    def __init__(
        self,
        encoding: Optional[str] = None,
        suppress_deflate_header: bool = False,
        executor: Optional[Executor] = None,
        max_sync_chunk_size: Optional[int] = MAX_SYNC_CHUNK_SIZE,
    ):
        super().__init__(
            mode=encoding_to_mode(encoding, suppress_deflate_header),
            executor=executor,
            max_sync_chunk_size=max_sync_chunk_size,
        )
        # Snapshot the current global backend so a later set_zlib_backend()
        # call cannot swap backends on this instance mid-stream.
        self._zlib_backend: Final = ZLibBackendWrapper(ZLibBackend._zlib_backend)
        self._decompressor = self._zlib_backend.decompressobj(wbits=self._mode)

    def decompress_sync(self, data: bytes, max_length: int = 0) -> bytes:
        """Decompress *data* synchronously (``max_length`` 0 means unlimited)."""
        return self._decompressor.decompress(data, max_length)

    async def decompress(self, data: bytes, max_length: int = 0) -> bytes:
        """Decompress *data* and return the decompressed bytes.

        Payloads larger than ``max_sync_chunk_size`` are decompressed in
        the executor; smaller ones directly in the event loop.
        """
        threshold = self._max_sync_chunk_size
        if threshold is None or len(data) <= threshold:
            return self.decompress_sync(data, max_length)
        # Large payload: offload to the executor so the event loop stays
        # responsive.
        return await asyncio.get_running_loop().run_in_executor(
            self._executor, self._decompressor.decompress, data, max_length
        )

    def flush(self, length: int = 0) -> bytes:
        """Return any remaining decompressed output.

        A positive ``length`` sizes the initial output buffer; otherwise
        the backend's default is used.
        """
        if length > 0:
            return self._decompressor.flush(length)
        return self._decompressor.flush()

    @property
    def eof(self) -> bool:
        """True once the end of the compressed stream has been reached."""
        return self._decompressor.eof
281
282
class BrotliDecompressor:
    """Incremental brotli decompressor.

    Supports both the 'brotlipy' and 'Brotli' packages, which share an
    import name but differ in API ('brotlipy' exposes decompress()/flush(),
    'Brotli' exposes process() and no flush()), so each call feature-detects
    which one is installed.
    """

    def __init__(self) -> None:
        if not HAS_BROTLI:
            raise RuntimeError(
                "The brotli decompression is not available. "
                "Please install `Brotli` module"
            )
        self._obj = brotli.Decompressor()

    def decompress_sync(self, data: bytes) -> bytes:
        """Feed *data* to the decompressor and return the available output."""
        decompress = getattr(self._obj, "decompress", None)
        if decompress is not None:  # 'brotlipy' API
            return cast(bytes, decompress(data))
        return cast(bytes, self._obj.process(data))  # 'Brotli' API

    def flush(self) -> bytes:
        """Flush remaining output ('Brotli' exposes no flush; returns b'')."""
        flush = getattr(self._obj, "flush", None)
        if flush is not None:  # 'brotlipy' API
            return cast(bytes, flush())
        return b""  # 'Brotli' API has no flush()
304
305
class ZSTDDecompressor:
    """Incremental zstd decompressor.

    Requires ``compression.zstd`` (Python 3.14+) or the ``backports.zstd``
    package on earlier versions.
    """

    def __init__(self) -> None:
        if HAS_ZSTD:
            self._obj = ZstdDecompressor()
            return
        raise RuntimeError(
            "The zstd decompression is not available. "
            "Please install `backports.zstd` module"
        )

    def decompress_sync(self, data: bytes) -> bytes:
        """Feed *data* to the decompressor and return the available output."""
        return self._obj.decompress(data)

    def flush(self) -> bytes:
        """Kept for interface symmetry with the other decompressors; always b''."""
        return b""