import asyncio
import sys
import zlib
from concurrent.futures import Executor
from typing import Any, Final, Optional, Protocol, TypedDict, cast

if sys.version_info >= (3, 12):
    from collections.abc import Buffer
else:
    from typing import Union

    Buffer = Union[bytes, bytearray, "memoryview[int]", "memoryview[bytes]"]

try:
    try:
        import brotlicffi as brotli
    except ImportError:
        import brotli

    HAS_BROTLI = True
except ImportError:  # pragma: no cover
    HAS_BROTLI = False

try:
    if sys.version_info >= (3, 14):
        from compression.zstd import ZstdDecompressor  # noqa: I900
    else:  # TODO(PY314): Remove mentions of backports.zstd across codebase
        from backports.zstd import ZstdDecompressor

    HAS_ZSTD = True
except ImportError:
    HAS_ZSTD = False


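# Payloads larger than this size (in bytes) are compressed or decompressed in
# an executor instead of on the event loop; see ZlibBaseHandler below.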
MAX_SYNC_CHUNK_SIZE = 1024


class ZLibCompressObjProtocol(Protocol):
    def compress(self, data: Buffer) -> bytes: ...
    def flush(self, mode: int = ..., /) -> bytes: ...


class ZLibDecompressObjProtocol(Protocol):
    def decompress(self, data: Buffer, max_length: int = ...) -> bytes: ...
    def flush(self, length: int = ..., /) -> bytes: ...

    @property
    def eof(self) -> bool: ...


class ZLibBackendProtocol(Protocol):
    MAX_WBITS: int
    Z_FULL_FLUSH: int
    Z_SYNC_FLUSH: int
    Z_BEST_SPEED: int
    Z_FINISH: int

    def compressobj(
        self,
        level: int = ...,
        method: int = ...,
        wbits: int = ...,
        memLevel: int = ...,
        strategy: int = ...,
        zdict: Optional[Buffer] = ...,
    ) -> ZLibCompressObjProtocol: ...
    def decompressobj(
        self, wbits: int = ..., zdict: Buffer = ...
    ) -> ZLibDecompressObjProtocol: ...

    def compress(
        self, data: Buffer, /, level: int = ..., wbits: int = ...
    ) -> bytes: ...
    def decompress(
        self, data: Buffer, /, wbits: int = ..., bufsize: int = ...
    ) -> bytes: ...

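# The standard library ``zlib`` module satisfies ZLibBackendProtocol; it is
# wrapped below as the default ``ZLibBackend``.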

class CompressObjArgs(TypedDict, total=False):
    wbits: int
    strategy: int
    level: int


class ZLibBackendWrapper:
    def __init__(self, _zlib_backend: ZLibBackendProtocol):
        self._zlib_backend: ZLibBackendProtocol = _zlib_backend

    @property
    def name(self) -> str:
        return getattr(self._zlib_backend, "__name__", "undefined")

    @property
    def MAX_WBITS(self) -> int:
        return self._zlib_backend.MAX_WBITS

    @property
    def Z_FULL_FLUSH(self) -> int:
        return self._zlib_backend.Z_FULL_FLUSH

    @property
    def Z_SYNC_FLUSH(self) -> int:
        return self._zlib_backend.Z_SYNC_FLUSH

    @property
    def Z_BEST_SPEED(self) -> int:
        return self._zlib_backend.Z_BEST_SPEED

    @property
    def Z_FINISH(self) -> int:
        return self._zlib_backend.Z_FINISH

    def compressobj(self, *args: Any, **kwargs: Any) -> ZLibCompressObjProtocol:
        return self._zlib_backend.compressobj(*args, **kwargs)

    def decompressobj(self, *args: Any, **kwargs: Any) -> ZLibDecompressObjProtocol:
        return self._zlib_backend.decompressobj(*args, **kwargs)

    def compress(self, data: Buffer, *args: Any, **kwargs: Any) -> bytes:
        return self._zlib_backend.compress(data, *args, **kwargs)

    def decompress(self, data: Buffer, *args: Any, **kwargs: Any) -> bytes:
        return self._zlib_backend.decompress(data, *args, **kwargs)

    # Everything not explicitly listed in the Protocol we just pass through
    def __getattr__(self, attrname: str) -> Any:
        return getattr(self._zlib_backend, attrname)


ZLibBackend: ZLibBackendWrapper = ZLibBackendWrapper(zlib)


def set_zlib_backend(new_zlib_backend: ZLibBackendProtocol) -> None:
    ZLibBackend._zlib_backend = new_zlib_backend
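# Example: swap in a faster zlib-compatible backend at runtime. ``zlib_ng``
# (and similarly ``isal.isal_zlib``) is an optional third-party package shown
# only as an illustration; any object implementing ZLibBackendProtocol works.
#
#     from zlib_ng import zlib_ng
#
#     set_zlib_backend(zlib_ng)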


def encoding_to_mode(
    encoding: Optional[str] = None,
    suppress_deflate_header: bool = False,
) -> int:
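    # The returned value is passed as ``wbits`` to the zlib backend:
    #   16 + MAX_WBITS -> expect/produce a gzip wrapper
    #   -MAX_WBITS     -> raw deflate stream, no zlib header or trailer
    #   MAX_WBITS      -> deflate stream with the standard zlib header/trailer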
    if encoding == "gzip":
        return 16 + ZLibBackend.MAX_WBITS

    return -ZLibBackend.MAX_WBITS if suppress_deflate_header else ZLibBackend.MAX_WBITS


class ZlibBaseHandler:
    def __init__(
        self,
        mode: int,
        executor: Optional[Executor] = None,
        max_sync_chunk_size: Optional[int] = MAX_SYNC_CHUNK_SIZE,
    ):
        self._mode = mode
        self._executor = executor
        self._max_sync_chunk_size = max_sync_chunk_size


class ZLibCompressor(ZlibBaseHandler):
    def __init__(
        self,
        encoding: Optional[str] = None,
        suppress_deflate_header: bool = False,
        level: Optional[int] = None,
        wbits: Optional[int] = None,
        strategy: Optional[int] = None,
        executor: Optional[Executor] = None,
        max_sync_chunk_size: Optional[int] = MAX_SYNC_CHUNK_SIZE,
    ):
        super().__init__(
            mode=(
                encoding_to_mode(encoding, suppress_deflate_header)
                if wbits is None
                else wbits
            ),
            executor=executor,
            max_sync_chunk_size=max_sync_chunk_size,
        )
        self._zlib_backend: Final = ZLibBackendWrapper(ZLibBackend._zlib_backend)

        kwargs: CompressObjArgs = {}
        kwargs["wbits"] = self._mode
        if strategy is not None:
            kwargs["strategy"] = strategy
        if level is not None:
            kwargs["level"] = level
        self._compressor = self._zlib_backend.compressobj(**kwargs)
        self._compress_lock = asyncio.Lock()

    def compress_sync(self, data: bytes) -> bytes:
        return self._compressor.compress(data)

    async def compress(self, data: bytes) -> bytes:
        """Compress the data and return the compressed bytes.

        Note that flush() must be called after the last call to compress().

        If the data is larger than max_sync_chunk_size, the compression
        is done in the executor. Otherwise, it is done in the event loop.
        """
        async with self._compress_lock:
            # To ensure the stream is consistent in the event
            # there are multiple writers, we need to lock
            # the compressor so that only one writer can
            # compress at a time.
            if (
                self._max_sync_chunk_size is not None
                and len(data) > self._max_sync_chunk_size
            ):
                return await asyncio.get_running_loop().run_in_executor(
                    self._executor, self._compressor.compress, data
                )
            return self.compress_sync(data)

    def flush(self, mode: Optional[int] = None) -> bytes:
        return self._compressor.flush(
            mode if mode is not None else self._zlib_backend.Z_FINISH
        )

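# Usage sketch for ZLibCompressor (all compressed chunks plus the final
# flush() output must be concatenated to form a valid stream):
#
#     compressor = ZLibCompressor(encoding="gzip")
#     body = await compressor.compress(b"hello world")
#     body += compressor.flush()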

class ZLibDecompressor(ZlibBaseHandler):
    def __init__(
        self,
        encoding: Optional[str] = None,
        suppress_deflate_header: bool = False,
        executor: Optional[Executor] = None,
        max_sync_chunk_size: Optional[int] = MAX_SYNC_CHUNK_SIZE,
    ):
        super().__init__(
            mode=encoding_to_mode(encoding, suppress_deflate_header),
            executor=executor,
            max_sync_chunk_size=max_sync_chunk_size,
        )
        self._zlib_backend: Final = ZLibBackendWrapper(ZLibBackend._zlib_backend)
        self._decompressor = self._zlib_backend.decompressobj(wbits=self._mode)

    def decompress_sync(self, data: bytes, max_length: int = 0) -> bytes:
        return self._decompressor.decompress(data, max_length)

    async def decompress(self, data: bytes, max_length: int = 0) -> bytes:
        """Decompress the data and return the decompressed bytes.

        If the data is larger than max_sync_chunk_size, the decompression
        is done in the executor. Otherwise, it is done in the event loop.
        """
        if (
            self._max_sync_chunk_size is not None
            and len(data) > self._max_sync_chunk_size
        ):
            return await asyncio.get_running_loop().run_in_executor(
                self._executor, self._decompressor.decompress, data, max_length
            )
        return self.decompress_sync(data, max_length)

    def flush(self, length: int = 0) -> bytes:
        return (
            self._decompressor.flush(length)
            if length > 0
            else self._decompressor.flush()
        )

    @property
    def eof(self) -> bool:
        return self._decompressor.eof

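# Usage sketch for ZLibDecompressor (feed compressed chunks as they arrive;
# ``eof`` reports whether the backend has seen the end of the stream). Here
# ``compressed_bytes`` stands for data received from the peer:
#
#     decompressor = ZLibDecompressor(encoding="gzip")
#     chunk = await decompressor.decompress(compressed_bytes)
#     remainder = decompressor.flush()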

class BrotliDecompressor:
    # Supports both 'brotlipy' and 'Brotli' packages
    # since they share an import name. The top branches
    # are for 'brotlipy' and bottom branches for 'Brotli'
    def __init__(self) -> None:
        if not HAS_BROTLI:
            raise RuntimeError(
                "The brotli decompression is not available. "
                "Please install `Brotli` module"
            )
        self._obj = brotli.Decompressor()

    def decompress_sync(self, data: bytes) -> bytes:
        if hasattr(self._obj, "decompress"):
            return cast(bytes, self._obj.decompress(data))
        return cast(bytes, self._obj.process(data))

    def flush(self) -> bytes:
        if hasattr(self._obj, "flush"):
            return cast(bytes, self._obj.flush())
        return b""


class ZSTDDecompressor:
    def __init__(self) -> None:
        if not HAS_ZSTD:
            raise RuntimeError(
                "The zstd decompression is not available. "
                "Please install `backports.zstd` module"
            )
        self._obj = ZstdDecompressor()

    def decompress_sync(self, data: bytes) -> bytes:
        return self._obj.decompress(data)

    def flush(self) -> bytes:
        return b""