Coverage for /pythoncovmergedfiles/medio/medio/src/aiohttp/aiohttp/compression_utils.py: 67%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

198 statements  

1import asyncio 

2import sys 

3import zlib 

4from abc import ABC, abstractmethod 

5from concurrent.futures import Executor 

6from typing import Any, Final, Protocol, TypedDict, cast 

7 

8if sys.version_info >= (3, 12): 

9 from collections.abc import Buffer 

10else: 

11 from typing import Union 

12 

13 Buffer = Union[bytes, bytearray, "memoryview[int]", "memoryview[bytes]"] 

14 

15try: 

16 try: 

17 import brotlicffi as brotli 

18 except ImportError: 

19 import brotli 

20 

21 HAS_BROTLI = True 

22except ImportError: 

23 HAS_BROTLI = False 

24 

25try: 

26 if sys.version_info >= (3, 14): 

27 from compression.zstd import ZstdDecompressor # noqa: I900 

28 else: # TODO(PY314): Remove mentions of backports.zstd across codebase 

29 from backports.zstd import ZstdDecompressor 

30 

31 HAS_ZSTD = True 

32except ImportError: 

33 HAS_ZSTD = False 

34 

35 

# Payloads whose len() exceeds this threshold are (de)compressed in an executor
# instead of inline on the event loop (see ZLibCompressor.compress and
# DecompressionBaseHandler.decompress). Passing None there disables offloading.
MAX_SYNC_CHUNK_SIZE = 4096

# Unlimited decompression constants - different libraries use different conventions
ZLIB_MAX_LENGTH_UNLIMITED = 0  # zlib uses 0 to mean unlimited
ZSTD_MAX_LENGTH_UNLIMITED = -1  # zstd uses -1 to mean unlimited

41 

42 

class ZLibCompressObjProtocol(Protocol):
    """Structural type for the object returned by a backend's ``compressobj()``."""

    # Feed more input; returns whatever compressed output is ready so far.
    def compress(self, data: Buffer) -> bytes: ...
    # Emit pending output; *mode* is a backend flush constant (e.g. Z_FINISH).
    def flush(self, mode: int = ..., /) -> bytes: ...

46 

47 

class ZLibDecompressObjProtocol(Protocol):
    """Structural type for the object returned by a backend's ``decompressobj()``."""

    # max_length bounds the output size; input left over because of that
    # bound is exposed via ``unconsumed_tail``.
    def decompress(self, data: Buffer, max_length: int = ...) -> bytes: ...
    def flush(self, length: int = ..., /) -> bytes: ...

    # True once the end of the compressed stream has been reached.
    @property
    def eof(self) -> bool: ...

    # Input not yet processed by decompress(); ZLibDecompressor prepends it to
    # the next chunk it feeds in.
    @property
    def unconsumed_tail(self) -> bytes: ...

57 

58 

class ZLibBackendProtocol(Protocol):
    """Structural type for a module that can stand in for the stdlib ``zlib``.

    These are the constants and functions accessed with static types through
    ``ZLibBackendWrapper``; the stdlib ``zlib`` module is the default backend.
    """

    MAX_WBITS: int
    Z_FULL_FLUSH: int
    Z_SYNC_FLUSH: int
    Z_BEST_SPEED: int
    Z_FINISH: int

    # Factory for incremental compressors (mirrors zlib.compressobj).
    def compressobj(
        self,
        level: int = ...,
        method: int = ...,
        wbits: int = ...,
        memLevel: int = ...,
        strategy: int = ...,
        zdict: Buffer | None = ...,
    ) -> ZLibCompressObjProtocol: ...
    # Factory for incremental decompressors (mirrors zlib.decompressobj).
    def decompressobj(
        self, wbits: int = ..., zdict: Buffer = ...
    ) -> ZLibDecompressObjProtocol: ...

    # One-shot helpers (mirror zlib.compress / zlib.decompress).
    def compress(
        self, data: Buffer, /, level: int = ..., wbits: int = ...
    ) -> bytes: ...
    def decompress(
        self, data: Buffer, /, wbits: int = ..., bufsize: int = ...
    ) -> bytes: ...

85 

86 

class CompressObjArgs(TypedDict, total=False):
    """Keyword arguments forwarded to a backend's ``compressobj()``.

    ``total=False`` makes every key optional: ZLibCompressor always sets
    ``wbits`` but only adds ``strategy``/``level`` when the caller gave them.
    """

    wbits: int
    strategy: int
    level: int

91 

92 

class ZLibBackendWrapper:
    """Proxy that forwards to a swappable zlib-compatible backend module.

    The explicitly defined properties and methods mirror
    ``ZLibBackendProtocol``; any other attribute is looked up on the wrapped
    backend through ``__getattr__``.
    """

    def __init__(self, _zlib_backend: ZLibBackendProtocol):
        self._zlib_backend: ZLibBackendProtocol = _zlib_backend

    @property
    def name(self) -> str:
        """Return the backend's ``__name__``, or ``"undefined"`` if it has none."""
        return getattr(self._zlib_backend, "__name__", "undefined")

    @property
    def MAX_WBITS(self) -> int:
        return self._zlib_backend.MAX_WBITS

    @property
    def Z_FULL_FLUSH(self) -> int:
        return self._zlib_backend.Z_FULL_FLUSH

    @property
    def Z_SYNC_FLUSH(self) -> int:
        return self._zlib_backend.Z_SYNC_FLUSH

    @property
    def Z_BEST_SPEED(self) -> int:
        return self._zlib_backend.Z_BEST_SPEED

    @property
    def Z_FINISH(self) -> int:
        return self._zlib_backend.Z_FINISH

    def compressobj(self, *args: Any, **kwargs: Any) -> ZLibCompressObjProtocol:
        return self._zlib_backend.compressobj(*args, **kwargs)

    def decompressobj(self, *args: Any, **kwargs: Any) -> ZLibDecompressObjProtocol:
        return self._zlib_backend.decompressobj(*args, **kwargs)

    def compress(self, data: Buffer, *args: Any, **kwargs: Any) -> bytes:
        return self._zlib_backend.compress(data, *args, **kwargs)

    def decompress(self, data: Buffer, *args: Any, **kwargs: Any) -> bytes:
        return self._zlib_backend.decompress(data, *args, **kwargs)

    # Everything not explicitly listed in the Protocol we just pass through
    def __getattr__(self, attrname: str) -> Any:
        """Forward any attribute not defined on the wrapper to the backend."""
        # __getattr__ only runs after normal lookup fails. If _zlib_backend
        # itself is missing (e.g. an instance created via cls.__new__ by
        # copy.copy() before its __dict__ is restored), evaluating
        # self._zlib_backend below would re-enter __getattr__ until
        # RecursionError. Raise AttributeError so hasattr()/getattr() probes
        # behave normally.
        if attrname == "_zlib_backend":
            raise AttributeError(attrname)
        return getattr(self._zlib_backend, attrname)

136 

137 

# Module-level backend proxy, wrapping the stdlib ``zlib`` module by default.
# set_zlib_backend() mutates this wrapper in place; ZLibCompressor and
# ZLibDecompressor wrap whatever backend it holds when they are constructed.
ZLibBackend: ZLibBackendWrapper = ZLibBackendWrapper(zlib)

139 

140 

def set_zlib_backend(new_zlib_backend: ZLibBackendProtocol) -> None:
    """Install *new_zlib_backend* as the module-wide zlib implementation.

    The existing ``ZLibBackend`` wrapper object is updated in place rather
    than rebound, so every module that imported it by name sees the new
    backend. Compressors/decompressors constructed earlier keep the backend
    they captured at construction time.
    """
    wrapper = ZLibBackend
    wrapper._zlib_backend = new_zlib_backend

143 

144 

def encoding_to_mode(
    encoding: str | None = None,
    suppress_deflate_header: bool = False,
) -> int:
    """Map an HTTP content encoding to a zlib ``wbits`` value.

    :param encoding: ``"gzip"`` selects gzip framing; any other value
        (including ``None``) selects zlib/deflate framing.
    :param suppress_deflate_header: for non-gzip encodings, request a raw
        deflate stream (negative ``wbits``) instead of a zlib-wrapped one.
    :returns: ``16 + MAX_WBITS`` for gzip, ``-MAX_WBITS`` for raw deflate,
        otherwise ``MAX_WBITS``.
    """
    window_bits = ZLibBackend.MAX_WBITS
    if encoding == "gzip":
        # Adding 16 to wbits selects the gzip container format in zlib.
        return window_bits + 16
    if suppress_deflate_header:
        # Negative wbits means a headerless (raw) deflate stream.
        return -window_bits
    return window_bits

153 

154 

class DecompressionBaseHandler(ABC):
    """Base class for incremental decompressors with optional executor offload.

    Subclasses implement ``decompress_sync`` and ``data_available``; the
    async ``decompress`` decides whether to run ``decompress_sync`` inline or
    in an executor based on the input size.
    """

    def __init__(
        self,
        executor: Executor | None = None,
        max_sync_chunk_size: int | None = MAX_SYNC_CHUNK_SIZE,
    ):
        """Base class for decompression handlers.

        :param executor: executor used for large inputs; ``None`` makes
            ``run_in_executor`` use the event loop's default executor.
        :param max_sync_chunk_size: inputs whose ``len()`` exceeds this are
            decompressed in *executor*; ``None`` always decompresses inline.
        """
        self._executor = executor
        self._max_sync_chunk_size = max_sync_chunk_size

    @abstractmethod
    def decompress_sync(
        self, data: bytes, max_length: int = ZLIB_MAX_LENGTH_UNLIMITED
    ) -> bytes:
        """Decompress the given data."""

    async def decompress(
        self, data: bytes, max_length: int = ZLIB_MAX_LENGTH_UNLIMITED
    ) -> bytes:
        """Decompress the given data.

        Runs ``decompress_sync`` in the executor when ``len(data)`` exceeds
        ``max_sync_chunk_size``, otherwise directly on the event loop.
        *max_length* uses the zlib convention (0 means unlimited).
        """
        if (
            self._max_sync_chunk_size is not None
            and len(data) > self._max_sync_chunk_size
        ):
            # get_running_loop() is the documented way to reach the loop from
            # inside a coroutine (and matches ZLibCompressor.compress).
            return await asyncio.get_running_loop().run_in_executor(
                self._executor, self.decompress_sync, data, max_length
            )
        return self.decompress_sync(data, max_length)

    @property
    @abstractmethod
    def data_available(self) -> bool:
        """Return True if more output is available by passing b""."""

188 

189 

class ZLibCompressor:
    """Incremental zlib/deflate/gzip compressor with optional executor offload."""

    def __init__(
        self,
        encoding: str | None = None,
        suppress_deflate_header: bool = False,
        level: int | None = None,
        wbits: int | None = None,
        strategy: int | None = None,
        executor: Executor | None = None,
        max_sync_chunk_size: int | None = MAX_SYNC_CHUNK_SIZE,
    ):
        """Create a compressor.

        :param encoding: ``"gzip"`` for gzip framing; other values select
            zlib/deflate framing (via ``encoding_to_mode``).
        :param suppress_deflate_header: request raw deflate output when
            *encoding* is not ``"gzip"``.
        :param level: compression level; backend default when ``None``.
        :param wbits: explicit ``wbits``; when given it overrides *encoding*
            and *suppress_deflate_header*.
        :param strategy: compression strategy; backend default when ``None``.
        :param executor: executor for large payloads; ``None`` uses the event
            loop's default executor.
        :param max_sync_chunk_size: payloads whose ``len()`` exceeds this are
            compressed in *executor*; ``None`` always compresses inline.
        """
        self._executor = executor
        self._max_sync_chunk_size = max_sync_chunk_size
        self._mode = (
            encoding_to_mode(encoding, suppress_deflate_header)
            if wbits is None
            else wbits
        )
        # Wrap the backend current at construction time; later
        # set_zlib_backend() calls do not affect this instance.
        self._zlib_backend: Final = ZLibBackendWrapper(ZLibBackend._zlib_backend)

        kwargs: CompressObjArgs = {}
        kwargs["wbits"] = self._mode
        if strategy is not None:
            kwargs["strategy"] = strategy
        if level is not None:
            kwargs["level"] = level
        self._compressor = self._zlib_backend.compressobj(**kwargs)

    def compress_sync(self, data: Buffer) -> bytes:
        """Compress *data* synchronously and return the output produced so far."""
        return self._compressor.compress(data)

    async def compress(self, data: Buffer) -> bytes:
        """Compress the data and return the compressed bytes.

        Note that flush() must be called after the last call to compress().

        If the data size is larger than max_sync_chunk_size, the compression
        will be done in the executor. Otherwise, the compression will be done
        in the event loop.

        **WARNING: This method is NOT cancellation-safe when used with flush().**
        If this operation is cancelled, the compressor state may be corrupted.
        The connection MUST be closed after cancellation to avoid data corruption
        in subsequent compress operations.

        For cancellation-safe compression (e.g., WebSocket), the caller MUST wrap
        compress() + flush() + send operations in a shield and lock to ensure atomicity.
        """
        # For large payloads, offload compression to executor to avoid blocking event loop
        should_use_executor = (
            self._max_sync_chunk_size is not None
            and len(data) > self._max_sync_chunk_size
        )
        if should_use_executor:
            # Go through compress_sync() so the inline and executor paths run
            # the same code (as DecompressionBaseHandler does with
            # decompress_sync).
            return await asyncio.get_running_loop().run_in_executor(
                self._executor, self.compress_sync, data
            )
        return self.compress_sync(data)

    def flush(self, mode: int | None = None) -> bytes:
        """Flush the compressor synchronously.

        *mode* defaults to the backend's ``Z_FINISH``.

        **WARNING: This method is NOT cancellation-safe when called after compress().**
        The flush() operation accesses shared compressor state. If compress() was
        cancelled, calling flush() may result in corrupted data. The connection MUST
        be closed after compress() cancellation.

        For cancellation-safe compression (e.g., WebSocket), the caller MUST wrap
        compress() + flush() + send operations in a shield and lock to ensure atomicity.
        """
        return self._compressor.flush(
            mode if mode is not None else self._zlib_backend.Z_FINISH
        )

263 

264 

class ZLibDecompressor(DecompressionBaseHandler):
    """Incremental zlib/deflate/gzip decompressor backed by ``ZLibBackend``."""

    def __init__(
        self,
        encoding: str | None = None,
        suppress_deflate_header: bool = False,
        executor: Executor | None = None,
        max_sync_chunk_size: int | None = MAX_SYNC_CHUNK_SIZE,
    ):
        super().__init__(executor=executor, max_sync_chunk_size=max_sync_chunk_size)
        self._mode = encoding_to_mode(encoding, suppress_deflate_header)
        # Pin the backend that is active right now for this stream's lifetime.
        self._zlib_backend: Final = ZLibBackendWrapper(ZLibBackend._zlib_backend)
        self._decompressor = self._zlib_backend.decompressobj(wbits=self._mode)
        self._last_empty = False

    def decompress_sync(
        self, data: Buffer, max_length: int = ZLIB_MAX_LENGTH_UNLIMITED
    ) -> bytes:
        """Decompress *data*, first re-feeding input left over from the last call."""
        leftover = self._decompressor.unconsumed_tail
        output = self._decompressor.decompress(leftover + data, max_length)
        # Some backends (e.g. isal) only signal exhaustion by yielding no output.
        self._last_empty = not output
        return output

    def flush(self, length: int = 0) -> bytes:
        """Flush the decompressor; a positive *length* is passed through."""
        if length <= 0:
            return self._decompressor.flush()
        return self._decompressor.flush(length)

    @property
    def data_available(self) -> bool:
        """True while the last call produced output or input is still pending."""
        return not self._last_empty or bool(self._decompressor.unconsumed_tail)

    @property
    def eof(self) -> bool:
        """Whether the backend has reached the end of the compressed stream."""
        return self._decompressor.eof

303 

304 

class BrotliDecompressor(DecompressionBaseHandler):
    """Incremental Brotli decompressor.

    Works with the object returned by ``brotli.Decompressor()`` from either
    supported package: objects exposing ``decompress``/``flush`` (the
    brotlipy / brotlicffi style) and objects exposing only ``process``
    (the ``Brotli`` package).
    """

    def __init__(
        self,
        executor: Executor | None = None,
        max_sync_chunk_size: int | None = MAX_SYNC_CHUNK_SIZE,
    ) -> None:
        """Decompress data using the Brotli library."""
        if not HAS_BROTLI:
            raise RuntimeError(
                "The brotli decompression is not available. "
                "Please install `Brotli` module"
            )
        super().__init__(executor=executor, max_sync_chunk_size=max_sync_chunk_size)
        self._last_empty = False
        self._obj = brotli.Decompressor()

    def decompress_sync(
        self, data: Buffer, max_length: int = ZLIB_MAX_LENGTH_UNLIMITED
    ) -> bytes:
        """Decompress *data*, honouring *max_length* unless it is unlimited."""
        decoder = self._obj
        step = (
            decoder.decompress if hasattr(decoder, "decompress") else decoder.process
        )
        if max_length == ZLIB_MAX_LENGTH_UNLIMITED:
            chunk = cast(bytes, step(data))
        else:
            chunk = cast(bytes, step(data, max_length))
        # Empty output is the only signal that brotli has nothing more to give.
        self._last_empty = chunk == b""
        return chunk

    def flush(self) -> bytes:
        """Flush the decompressor if the backend supports it, else return b""."""
        if not hasattr(self._obj, "flush"):
            return b""
        return cast(bytes, self._obj.flush())

    @property
    def data_available(self) -> bool:
        """True unless the stream is finished or the last call yielded nothing."""
        return not (self._obj.is_finished() or self._last_empty)

351 

352 

class ZSTDDecompressor(DecompressionBaseHandler):
    """Incremental Zstandard decompressor that handles multi-frame streams."""

    def __init__(
        self,
        executor: Executor | None = None,
        max_sync_chunk_size: int | None = MAX_SYNC_CHUNK_SIZE,
    ) -> None:
        if not HAS_ZSTD:
            raise RuntimeError(
                "The zstd decompression is not available. "
                "Please install `backports.zstd` module"
            )
        self._obj = ZstdDecompressor()
        # Input belonging to a following frame that was not decoded because the
        # caller's max_length budget was already used up; fed in on next call.
        self._pending_unused_data: bytes | None = None
        super().__init__(executor=executor, max_sync_chunk_size=max_sync_chunk_size)

    def decompress_sync(
        self, data: bytes, max_length: int = ZLIB_MAX_LENGTH_UNLIMITED
    ) -> bytes:
        """Decompress *data*, producing at most *max_length* bytes in total.

        *max_length* uses the zlib convention (0 means unlimited).
        """
        # zstd uses -1 for unlimited, while zlib uses 0 for unlimited
        # Convert the zlib convention (0=unlimited) to zstd convention (-1=unlimited)
        zstd_max_length = (
            ZSTD_MAX_LENGTH_UNLIMITED
            if max_length == ZLIB_MAX_LENGTH_UNLIMITED
            else max_length
        )
        limited = zstd_max_length != ZSTD_MAX_LENGTH_UNLIMITED
        if self._pending_unused_data is not None:
            data = self._pending_unused_data + data
            self._pending_unused_data = None

        first = self._obj.decompress(data, zstd_max_length)
        chunks = [first]
        produced = len(first)

        # Handle multi-frame zstd streams.
        # https://datatracker.ietf.org/doc/html/rfc8878#section-3.1.1
        # ZstdDecompressor handles one frame only. When a frame ends,
        # eof becomes True and any trailing data goes to unused_data.
        # We create a fresh decompressor to continue with the next frame.
        while self._obj.eof and self._obj.unused_data:
            unused_data = self._obj.unused_data
            self._obj = ZstdDecompressor()
            remaining = ZSTD_MAX_LENGTH_UNLIMITED
            if limited:
                # Budget left for this call = total limit minus everything
                # produced so far (each frame counted exactly once).
                remaining = zstd_max_length - produced
                if remaining <= 0:
                    self._pending_unused_data = unused_data
                    break
            chunk = self._obj.decompress(unused_data, remaining)
            chunks.append(chunk)
            produced += len(chunk)

        # Frame ended exactly at chunk boundary — no unused_data, but the
        # next feed_data() call would fail on the spent decompressor.
        # Prepare a fresh one for the next chunk.
        if self._obj.eof:
            self._obj = ZstdDecompressor()

        # join() avoids quadratic copying when many frames arrive at once.
        return b"".join(chunks)

    def flush(self) -> bytes:
        """Return b""; all output is returned directly by decompress_sync()."""
        return b""

    @property
    def data_available(self) -> bool:
        """True if more output can be produced without new input.

        That is the case when the current decompressor holds buffered input
        it has not finished emitting, or when data for a following frame is
        queued in ``_pending_unused_data``.
        """
        return (
            not self._obj.needs_input and not self._obj.eof
        ) or self._pending_unused_data is not None