Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/smart_open/compression.py: 65%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

81 statements  

1# 

2# Copyright (C) 2020 Radim Rehurek <me@radimrehurek.com> 

3# 

4# This code is distributed under the terms and conditions 

5# from the MIT License (MIT). 

6# 

7"""Implements the compression layer of the `smart_open` library.""" 

8 

9from __future__ import annotations 

10 

import io
import logging
import os
from pathlib import Path
from typing import IO, TYPE_CHECKING, Any

15 

16if TYPE_CHECKING: 

17 from smart_open._typing import CompressionKwargs, Compressor 

18 

NO_COMPRESSION = "disable"
"""Use no compression. Read/write the data as-is."""
INFER_FROM_EXTENSION = "infer_from_extension"
"""Determine the compression to use from the file extension.

See get_supported_extensions().
"""

logger = logging.getLogger(__name__)

# Registered [de]compression callbacks, keyed by file extension including the
# leading period (e.g. ".gz"). register_compressor() lower-cases keys before
# storing them here.
_COMPRESSOR_REGISTRY: dict[str, Compressor] = dict()

30 

31 

def get_supported_compression_types() -> list[str]:
    """Return the list of supported compression types available to open.

    The list starts with the two special values NO_COMPRESSION and
    INFER_FROM_EXTENSION. They are followed by every registered file
    extension, in the sorted order produced by get_supported_extensions().

    See the compression parameter to smart_open.open().
    """
    return [NO_COMPRESSION, INFER_FROM_EXTENSION, *get_supported_extensions()]

38 

39 

def get_supported_extensions() -> list[str]:
    """Return the extensions (e.g. ".gz") that currently have a registered compressor, sorted."""
    # Iterating the registry yields its keys; sort a fresh copy so the
    # registry itself is never exposed to callers.
    extensions = list(_COMPRESSOR_REGISTRY)
    extensions.sort()
    return extensions

43 

44 

def register_compressor(ext: str, callback: Compressor) -> None:
    """Register a callback for transparently [de]compressing files with a specific extension.

    Args:
        ext: The extension. Must be a string that includes the leading period, e.g. `.gz`.
            It is lower-cased before being stored, so the same handler is used
            regardless of the extension's case.
        callback: The callback. It must accept two positional arguments, file_obj and mode,
            and is recommended to also accept **kwargs so that whatever the caller passes
            via smart_open.open(..., compression_kwargs={...}) reaches the underlying
            library unchanged. Callbacks with the legacy (file_obj, mode) signature still
            work, but will raise TypeError if the caller supplies compression_kwargs
            that the callback doesn't declare.

    Raises:
        ValueError: If `ext` is not a string starting with a period.

    Example:
        Instruct smart_open to use the `lzma` module whenever opening a file
        with a .xz extension (see README.md for the complete example showing I/O):

        >>> def _handle_xz(file_obj, mode, **kwargs):
        ...     import lzma
        ...
        ...     return lzma.open(filename=file_obj, mode=mode, **kwargs)
        >>>
        >>> register_compressor(".xz", _handle_xz)

        This is just an example: `lzma` is in the standard library and is registered by default.
    """
    # Check the type explicitly. A non-str sequence such as ["."] would otherwise
    # pass a first-character test and then fail below with AttributeError on
    # .lower(), instead of the documented ValueError.
    if not (isinstance(ext, str) and ext.startswith(".")):
        msg = f"ext must be a string starting with ., not {ext!r}"
        raise ValueError(msg)
    ext = ext.lower()
    if ext in _COMPRESSOR_REGISTRY:
        logger.warning("overriding existing compression handler for %r", ext)
    _COMPRESSOR_REGISTRY[ext] = callback

80 

81 

def _maybe_wrap_buffered(file_obj: Any, mode: str) -> IO[bytes]:
    """Wrap a binary-mode stream in an io.BufferedWriter or io.BufferedReader.

    Only modes containing "b" are considered. "w" takes precedence and yields a
    BufferedWriter; otherwise "r" yields a BufferedReader. Every other mode
    (text modes, or binary modes with neither "w" nor "r", such as "ab")
    returns `file_obj` unchanged.

    Motivation: https://github.com/piskvorky/smart_open/issues/760#issuecomment-1553971657
    """
    if "b" not in mode:
        return file_obj
    if "w" in mode:
        return io.BufferedWriter(file_obj)
    if "r" in mode:
        return io.BufferedReader(file_obj)
    return file_obj

90 

91 

def _handle_bz2(file_obj: IO[bytes], mode: str, **kwargs: Any) -> IO[Any]:
    """Open `file_obj` through the stdlib `bz2` module, forwarding `kwargs` to bz2.open()."""
    import bz2

    return _maybe_wrap_buffered(
        bz2.open(filename=file_obj, mode=mode, **kwargs),  # noqa: SIM115 # returned to the caller, not closed here
        mode,
    )

97 

98 

def _handle_gzip(file_obj: IO[bytes], mode: str, **kwargs: Any) -> IO[Any]:
    """Open `file_obj` through the stdlib `gzip` module, forwarding `kwargs` to gzip.open()."""
    import gzip

    return _maybe_wrap_buffered(
        gzip.open(filename=file_obj, mode=mode, **kwargs),  # noqa: SIM115 # returned to the caller, not closed here
        mode,
    )

104 

105 

def _handle_zstd(file_obj: IO[bytes], mode: str, **kwargs: Any) -> IO[Any]:
    """Open `file_obj` as a Zstandard stream, forwarding `kwargs` to zstd.open().

    Uses `compression.zstd` on Python 3.14+ and the `backports.zstd` package on
    older interpreters. Both are imported only when this handler runs.
    """
    import sys

    if sys.version_info < (3, 14):
        from backports import zstd
    else:
        from compression import zstd

    # zstd.open() is overloaded and the dynamic **kwargs cannot be matched against
    # those overloads by a type checker, so call it through an Any-typed alias.
    opener: Any = zstd.open
    return _maybe_wrap_buffered(opener(file_obj, mode=mode, **kwargs), mode)

117 

118 

def _handle_xz(file_obj: IO[bytes], mode: str, **kwargs: Any) -> IO[Any]:
    """Open `file_obj` through the stdlib `lzma` module, forwarding `kwargs` to lzma.open()."""
    import lzma

    return _maybe_wrap_buffered(
        lzma.open(filename=file_obj, mode=mode, **kwargs),  # noqa: SIM115 # returned to the caller, not closed here
        mode,
    )

124 

125 

def _handle_lz4(file_obj: IO[bytes], mode: str, **kwargs: Any) -> IO[Any]:
    """Open `file_obj` as an LZ4 frame stream via the `lz4` package (imported on first use).

    `kwargs` are forwarded unchanged to lz4.frame.open().
    """
    import lz4.frame

    return _maybe_wrap_buffered(lz4.frame.open(file_obj, mode=mode, **kwargs), mode)

131 

132 

def compression_wrapper(
    file_obj: IO[Any],
    mode: str,
    compression: str = INFER_FROM_EXTENSION,
    filename: str | None = None,
    compression_kwargs: CompressionKwargs | None = None,
) -> IO[Any]:
    """Wrap `file_obj` with an appropriate [de]compression mechanism based on its file extension.

    If the filename extension isn't recognized, simply return the original `file_obj` unchanged.

    `file_obj` must either be a filehandle object, or a class which behaves like one.

    If `filename` is specified, it will be used to extract the extension.
    If not, the `file_obj.name` attribute is used as the filename. The name may be
    a `str`, `bytes` or path-like object, and the extension is matched case-insensitively.

    If `compression_kwargs` is specified, its contents are forwarded as keyword
    arguments to the registered compressor callback.

    Args:
        file_obj: The stream to wrap.
        mode: The mode passed through to the compressor callback.
        compression: NO_COMPRESSION to return `file_obj` as-is, INFER_FROM_EXTENSION
            to choose the compressor from the name's extension, or an explicit
            registered extension such as ".gz".
        filename: Optional name used for extension inference instead of `file_obj.name`.
        compression_kwargs: Optional keyword arguments for the compressor callback.

    Returns:
        The object returned by the matching compressor callback, or `file_obj`
        itself when no compressor applies.

    Raises:
        ValueError: If a compressor applies and `mode` is an update mode
            (contains "+", e.g. "rb+" or "r+b").
    """
    if compression == NO_COMPRESSION:
        return file_obj
    if compression == INFER_FROM_EXTENSION:
        try:
            # os.fsdecode normalizes bytes (e.g. open(b"x.gz").name) and path-like
            # names to str. ints, None and other non-path names raise TypeError.
            inferred_name = os.fsdecode(filename or file_obj.name).lower()
        except (AttributeError, TypeError):
            logger.warning(
                "unable to transparently decompress %r because it seems to lack a string-like .name", file_obj
            )
            return file_obj
        compression = Path(inferred_name).suffix

    callback = _COMPRESSOR_REGISTRY.get(compression)
    if callback is None:
        return file_obj

    # "+" may appear anywhere in a valid mode string ("rb+" and "r+b" are both
    # legal), so test for membership rather than only the final character.
    if "+" in mode:
        msg = f"transparent (de)compression unsupported for mode {mode!r}"
        raise ValueError(msg)

    return callback(file_obj, mode, **(compression_kwargs or {}))

174 

175 

# Built-in compressors, registered at import time in this order. The handlers
# are named module-level functions rather than lambdas so that stack traces
# stay readable.
for _ext, _handler in (
    (".bz2", _handle_bz2),
    (".gz", _handle_gzip),
    (".zst", _handle_zstd),
    (".xz", _handle_xz),
    (".lz4", _handle_lz4),
):
    register_compressor(_ext, _handler)
del _ext, _handler