1import asyncio
2import sys
3import zlib
4from concurrent.futures import Executor
5from typing import Any, Final, Optional, Protocol, TypedDict, cast
6
7if sys.version_info >= (3, 12):
8 from collections.abc import Buffer
9else:
10 from typing import Union
11
12 Buffer = Union[bytes, bytearray, "memoryview[int]", "memoryview[bytes]"]
13
try:
    # Prefer the CFFI binding (usable on PyPy); fall back to the C extension.
    # Both packages install under the same import name ``brotli``.
    try:
        import brotlicffi as brotli
    except ImportError:
        import brotli

    HAS_BROTLI = True
except ImportError:  # pragma: no cover
    HAS_BROTLI = False

# Payloads at or below this many bytes are (de)compressed inline on the event
# loop; larger payloads are offloaded to an executor (see handlers below).
MAX_SYNC_CHUNK_SIZE = 1024
25
26
class ZLibCompressObjProtocol(Protocol):
    """Structural type of the streaming object returned by ``compressobj()``."""

    def compress(self, data: Buffer) -> bytes: ...
    def flush(self, mode: int = ..., /) -> bytes: ...
30
31
class ZLibDecompressObjProtocol(Protocol):
    """Structural type of the streaming object returned by ``decompressobj()``."""

    def decompress(self, data: Buffer, max_length: int = ...) -> bytes: ...
    def flush(self, length: int = ..., /) -> bytes: ...

    @property
    def eof(self) -> bool:
        """True once the end of the compressed stream has been reached."""
        ...
38
39
class ZLibBackendProtocol(Protocol):
    """Module-level surface a zlib-compatible backend must provide.

    Satisfied by the stdlib ``zlib`` module and by drop-in replacement
    modules that mirror its API.
    """

    # Constants mirrored from ``zlib`` (window size and flush modes).
    MAX_WBITS: int
    Z_FULL_FLUSH: int
    Z_SYNC_FLUSH: int
    Z_BEST_SPEED: int
    Z_FINISH: int

    # Streaming interfaces.
    def compressobj(
        self,
        level: int = ...,
        method: int = ...,
        wbits: int = ...,
        memLevel: int = ...,
        strategy: int = ...,
        zdict: Optional[Buffer] = ...,
    ) -> ZLibCompressObjProtocol: ...
    def decompressobj(
        self, wbits: int = ..., zdict: Buffer = ...
    ) -> ZLibDecompressObjProtocol: ...

    # One-shot interfaces.
    def compress(
        self, data: Buffer, /, level: int = ..., wbits: int = ...
    ) -> bytes: ...
    def decompress(
        self, data: Buffer, /, wbits: int = ..., bufsize: int = ...
    ) -> bytes: ...
66
67
class CompressObjArgs(TypedDict, total=False):
    """Keyword arguments forwarded to ``compressobj()``.

    ``total=False`` makes every key optional, so only settings the caller
    explicitly supplied are passed through to the backend.
    """

    wbits: int
    strategy: int
    level: int
72
73
class ZLibBackendWrapper:
    """Delegating facade over a zlib-compatible compression module.

    Holds a reference to the active backend and forwards the
    :class:`ZLibBackendProtocol` surface to it, so the backend can be
    swapped at runtime (see :func:`set_zlib_backend`) without callers
    keeping stale module references.
    """

    def __init__(self, _zlib_backend: ZLibBackendProtocol):
        self._zlib_backend: ZLibBackendProtocol = _zlib_backend

    @property
    def name(self) -> str:
        """Import name of the wrapped backend, or ``"undefined"``."""
        # Real modules expose ``__name__``; tolerate backends that don't.
        return getattr(self._zlib_backend, "__name__", "undefined")

    @property
    def MAX_WBITS(self) -> int:
        return self._zlib_backend.MAX_WBITS

    @property
    def Z_FULL_FLUSH(self) -> int:
        return self._zlib_backend.Z_FULL_FLUSH

    @property
    def Z_SYNC_FLUSH(self) -> int:
        return self._zlib_backend.Z_SYNC_FLUSH

    @property
    def Z_BEST_SPEED(self) -> int:
        return self._zlib_backend.Z_BEST_SPEED

    @property
    def Z_FINISH(self) -> int:
        return self._zlib_backend.Z_FINISH

    def compressobj(self, *args: Any, **kwargs: Any) -> ZLibCompressObjProtocol:
        """Create a streaming compressor on the active backend."""
        backend = self._zlib_backend
        return backend.compressobj(*args, **kwargs)

    def decompressobj(self, *args: Any, **kwargs: Any) -> ZLibDecompressObjProtocol:
        """Create a streaming decompressor on the active backend."""
        backend = self._zlib_backend
        return backend.decompressobj(*args, **kwargs)

    def compress(self, data: Buffer, *args: Any, **kwargs: Any) -> bytes:
        """One-shot compression via the active backend."""
        backend = self._zlib_backend
        return backend.compress(data, *args, **kwargs)

    def decompress(self, data: Buffer, *args: Any, **kwargs: Any) -> bytes:
        """One-shot decompression via the active backend."""
        backend = self._zlib_backend
        return backend.decompress(data, *args, **kwargs)

    # Anything not covered by the Protocol is forwarded unchanged.
    def __getattr__(self, attrname: str) -> Any:
        return getattr(self._zlib_backend, attrname)
117
118
# Process-wide default backend: starts as the stdlib ``zlib`` module and can
# be replaced at runtime via ``set_zlib_backend()``.
ZLibBackend: ZLibBackendWrapper = ZLibBackendWrapper(zlib)
120
121
def set_zlib_backend(new_zlib_backend: ZLibBackendProtocol) -> None:
    """Swap the process-wide compression backend.

    Compressor/decompressor instances created before the swap keep the
    backend they captured at construction time; only objects created
    afterwards use *new_zlib_backend*.
    """
    ZLibBackend._zlib_backend = new_zlib_backend
124
125
def encoding_to_mode(
    encoding: Optional[str] = None,
    suppress_deflate_header: bool = False,
) -> int:
    """Translate a content-encoding name into a zlib ``wbits`` value.

    ``"gzip"`` selects a gzip container (``16 + MAX_WBITS`` per the zlib
    convention); otherwise a zlib-wrapped deflate stream is used, or a raw
    deflate stream (negative ``wbits``) when *suppress_deflate_header* is
    true.
    """
    max_wbits = ZLibBackend.MAX_WBITS
    if encoding == "gzip":
        return 16 + max_wbits
    if suppress_deflate_header:
        return -max_wbits
    return max_wbits
134
135
class ZlibBaseHandler:
    """Shared state for the compression/decompression handlers.

    Stores the zlib mode (``wbits``), the executor used for offloading
    large payloads, and the size threshold above which work is moved off
    the event loop (``None`` disables offloading).
    """

    def __init__(
        self,
        mode: int,
        executor: Optional[Executor] = None,
        max_sync_chunk_size: Optional[int] = MAX_SYNC_CHUNK_SIZE,
    ):
        self._mode = mode
        self._executor = executor
        self._max_sync_chunk_size = max_sync_chunk_size
146
147
class ZLibCompressor(ZlibBaseHandler):
    """Asynchronous-friendly streaming compressor.

    Small payloads are compressed inline on the event loop; payloads larger
    than ``max_sync_chunk_size`` are offloaded to an executor.
    """

    def __init__(
        self,
        encoding: Optional[str] = None,
        suppress_deflate_header: bool = False,
        level: Optional[int] = None,
        wbits: Optional[int] = None,
        strategy: Optional[int] = None,
        executor: Optional[Executor] = None,
        max_sync_chunk_size: Optional[int] = MAX_SYNC_CHUNK_SIZE,
    ):
        # An explicit wbits wins; otherwise derive it from the encoding.
        if wbits is None:
            resolved_mode = encoding_to_mode(encoding, suppress_deflate_header)
        else:
            resolved_mode = wbits
        super().__init__(
            mode=resolved_mode,
            executor=executor,
            max_sync_chunk_size=max_sync_chunk_size,
        )
        # Snapshot the backend now so a later set_zlib_backend() call
        # cannot change this stream's implementation mid-flight.
        self._zlib_backend: Final = ZLibBackendWrapper(ZLibBackend._zlib_backend)

        # Only pass along the options the caller actually provided.
        opts: CompressObjArgs = {"wbits": self._mode}
        if strategy is not None:
            opts["strategy"] = strategy
        if level is not None:
            opts["level"] = level
        self._compressor = self._zlib_backend.compressobj(**opts)
        self._compress_lock = asyncio.Lock()

    def compress_sync(self, data: bytes) -> bytes:
        """Compress *data* synchronously on the calling thread."""
        return self._compressor.compress(data)

    async def compress(self, data: bytes) -> bytes:
        """Compress the data and return the compressed bytes.

        Note that flush() must be called after the last call to compress().

        If the data is larger than max_sync_chunk_size, the compression
        runs in the executor; otherwise it runs inline on the event loop.
        """
        # The underlying compressor is stateful, so when there are
        # multiple writers the lock keeps their chunks from interleaving
        # and corrupting the stream.
        async with self._compress_lock:
            threshold = self._max_sync_chunk_size
            if threshold is None or len(data) <= threshold:
                return self.compress_sync(data)
            loop = asyncio.get_running_loop()
            return await loop.run_in_executor(
                self._executor, self._compressor.compress, data
            )

    def flush(self, mode: Optional[int] = None) -> bytes:
        """Flush the compressor, defaulting to the backend's Z_FINISH mode."""
        if mode is None:
            mode = self._zlib_backend.Z_FINISH
        return self._compressor.flush(mode)
209
210
class ZLibDecompressor(ZlibBaseHandler):
    """Asynchronous-friendly streaming decompressor.

    Small payloads are decompressed inline on the event loop; payloads
    larger than ``max_sync_chunk_size`` are offloaded to an executor.
    """

    def __init__(
        self,
        encoding: Optional[str] = None,
        suppress_deflate_header: bool = False,
        executor: Optional[Executor] = None,
        max_sync_chunk_size: Optional[int] = MAX_SYNC_CHUNK_SIZE,
    ):
        super().__init__(
            mode=encoding_to_mode(encoding, suppress_deflate_header),
            executor=executor,
            max_sync_chunk_size=max_sync_chunk_size,
        )
        # Snapshot the backend now so a later set_zlib_backend() call
        # cannot change this stream's implementation mid-flight.
        self._zlib_backend: Final = ZLibBackendWrapper(ZLibBackend._zlib_backend)
        self._decompressor = self._zlib_backend.decompressobj(wbits=self._mode)

    def decompress_sync(self, data: bytes, max_length: int = 0) -> bytes:
        """Decompress *data* synchronously on the calling thread."""
        return self._decompressor.decompress(data, max_length)

    async def decompress(self, data: bytes, max_length: int = 0) -> bytes:
        """Decompress the data and return the decompressed bytes.

        If the data is larger than max_sync_chunk_size, the decompression
        runs in the executor; otherwise it runs inline on the event loop.
        """
        threshold = self._max_sync_chunk_size
        if threshold is None or len(data) <= threshold:
            return self.decompress_sync(data, max_length)
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(
            self._executor, self._decompressor.decompress, data, max_length
        )

    def flush(self, length: int = 0) -> bytes:
        """Return any remaining decompressed output."""
        if length > 0:
            return self._decompressor.flush(length)
        return self._decompressor.flush()

    @property
    def eof(self) -> bool:
        """True once the end of the compressed stream has been reached."""
        return self._decompressor.eof
256
257
class BrotliDecompressor:
    """Streaming brotli decompressor supporting both binding flavors.

    The 'brotlipy'/'brotlicffi' and 'Brotli' packages share the import
    name ``brotli`` but expose slightly different decompressor APIs, so
    each method probes the wrapped object for the variant at call time.
    """

    def __init__(self) -> None:
        if not HAS_BROTLI:
            raise RuntimeError(
                "The brotli decompression is not available. "
                "Please install `Brotli` module"
            )
        self._obj = brotli.Decompressor()

    def decompress_sync(self, data: bytes) -> bytes:
        """Feed *data* to the decompressor and return what it yields."""
        decompress = getattr(self._obj, "decompress", None)
        if decompress is not None:
            return cast(bytes, decompress(data))
        # The 'Brotli' package exposes process() instead of decompress().
        return cast(bytes, self._obj.process(data))

    def flush(self) -> bytes:
        """Flush buffered output; a no-op for bindings without flush()."""
        flush = getattr(self._obj, "flush", None)
        if flush is not None:
            return cast(bytes, flush())
        return b""