1#
2# Copyright (C) 2020 Radim Rehurek <me@radimrehurek.com>
3#
4# This code is distributed under the terms and conditions
5# from the MIT License (MIT).
6#
7"""Implements the compression layer of the `smart_open` library."""
8
9from __future__ import annotations
10
11import io
12import logging
13from pathlib import Path
14from typing import IO, TYPE_CHECKING, Any
15
16if TYPE_CHECKING:
17 from smart_open._typing import CompressionKwargs, Compressor
18
logger = logging.getLogger(name=__name__)

# Extension (lower-cased, with its leading period, e.g. ".gz") -> callback that
# wraps a file object with the matching (de)compressor. Filled in through
# register_compressor().
_COMPRESSOR_REGISTRY: dict[str, Compressor] = {}

NO_COMPRESSION = "disable"
"""Use no compression. Read/write the data as-is."""
INFER_FROM_EXTENSION = "infer_from_extension"
"""Determine the compression to use from the file extension.

See get_supported_extensions().
"""
30
31
def get_supported_compression_types() -> list[str]:
    """Return the list of supported compression types available to open.

    The list starts with the two special values NO_COMPRESSION and
    INFER_FROM_EXTENSION, followed by every registered extension in sorted
    order.

    See the `compression` parameter to smart_open.open().
    """
    return [NO_COMPRESSION, INFER_FROM_EXTENSION, *get_supported_extensions()]
38
39
def get_supported_extensions() -> list[str]:
    """List every extension that currently has a registered compressor.

    Extensions are returned in sorted order, lower-cased and including their
    leading period (e.g. ``".gz"``), exactly as register_compressor() stored them.
    """
    extensions = list(_COMPRESSOR_REGISTRY)
    extensions.sort()
    return extensions
43
44
def register_compressor(ext: str, callback: Compressor) -> None:
    """Register a callback for transparently decompressing files with a specific extension.

    The extension is lower-cased before being stored. Registering an extension
    that already has a handler replaces the old handler and logs a warning.

    Args:
        ext: The extension. Must include the leading period, e.g. `.gz`.
        callback: The callback. It must accept two positional arguments, file_obj and mode,
            and is recommended to also accept **kwargs so that whatever the caller passes
            via smart_open.open(..., compression_kwargs={...}) reaches the underlying
            library unchanged. Callbacks with the legacy (file_obj, mode) signature still
            work, but will raise TypeError if the caller supplies compression_kwargs
            that the callback doesn't declare.

    Raises:
        ValueError: If `ext` is empty or does not start with a period.

    Example:
        Instruct smart_open to use the `lzma` module whenever opening a file
        with a .xz extension (see README.md for the complete example showing I/O):

        >>> def _handle_xz(file_obj, mode, **kwargs):
        ...     import lzma
        ...
        ...     return lzma.open(filename=file_obj, mode=mode, **kwargs)
        >>>
        >>> register_compressor(".xz", _handle_xz)

        This is just an example: `lzma` is in the standard library and is registered by default.
    """
    # Reject empty/None values and anything whose first element isn't a period.
    if not ext or ext[0] != ".":
        message = f"ext must be a string starting with ., not {ext!r}"
        raise ValueError(message)

    key = ext.lower()
    if key in _COMPRESSOR_REGISTRY:
        logger.warning("overriding existing compression handler for %r", key)
    _COMPRESSOR_REGISTRY[key] = callback
80
81
def _maybe_wrap_buffered(file_obj: Any, mode: str) -> IO[Any]:
    """Wrap a binary stream produced by a compression library in an io buffer.

    See https://github.com/piskvorky/smart_open/issues/760#issuecomment-1553971657

    Text-mode streams (no ``"b"`` in *mode*) are returned unchanged. Binary
    streams opened for writing get an :class:`io.BufferedWriter`. This covers
    ``"w"``, append (``"a"``) and exclusive-create (``"x"``). Binary streams
    opened for reading get an :class:`io.BufferedReader`.

    Args:
        file_obj: The stream returned by a compression library's ``open()``.
        mode: The mode *file_obj* was opened with.

    Returns:
        The buffered wrapper, or *file_obj* itself when no wrapping applies.
        In text mode this is a text stream, hence the ``IO[Any]`` annotation.
    """
    if "b" not in mode:
        return file_obj
    # "a" and "x" are write modes too; treat them the same as "w" so that
    # appending/creating gets the same buffering as overwriting.
    if any(flag in mode for flag in "wax"):
        return io.BufferedWriter(file_obj)
    if "r" in mode:
        return io.BufferedReader(file_obj)
    return file_obj
90
91
def _handle_bz2(file_obj: IO[bytes], mode: str, **kwargs: Any) -> IO[Any]:
    """Open *file_obj* through ``bz2.open()`` and buffer it when in binary mode.

    Any extra keyword arguments are forwarded to ``bz2.open()`` unchanged.
    """
    import bz2

    # The stream is handed back to the caller, so no context manager here.
    return _maybe_wrap_buffered(
        bz2.open(filename=file_obj, mode=mode, **kwargs),  # noqa: SIM115
        mode,
    )
97
98
def _handle_gzip(file_obj: IO[bytes], mode: str, **kwargs: Any) -> IO[Any]:
    """Open *file_obj* through ``gzip.open()`` and buffer it when in binary mode.

    Any extra keyword arguments are forwarded to ``gzip.open()`` unchanged.
    """
    import gzip

    # The stream is handed back to the caller, so no context manager here.
    return _maybe_wrap_buffered(
        gzip.open(filename=file_obj, mode=mode, **kwargs),  # noqa: SIM115
        mode,
    )
104
105
def _handle_zstd(file_obj: IO[bytes], mode: str, **kwargs: Any) -> IO[Any]:
    """Open *file_obj* through ``zstd.open()`` and buffer it when in binary mode.

    On Python 3.14+ the ``compression.zstd`` module is used; older interpreters
    import ``zstd`` from the ``backports`` package instead. Extra keyword
    arguments are forwarded to ``zstd.open()`` unchanged.
    """
    import sys

    if sys.version_info < (3, 14):
        from backports import zstd
    else:
        from compression import zstd
    # dynamic **kwargs cannot be matched against zstd.open()'s overloads, so go through Any
    opener: Any = zstd.open
    return _maybe_wrap_buffered(opener(file_obj, mode=mode, **kwargs), mode)
117
118
def _handle_xz(file_obj: IO[bytes], mode: str, **kwargs: Any) -> IO[Any]:
    """Open *file_obj* through ``lzma.open()`` and buffer it when in binary mode.

    Any extra keyword arguments are forwarded to ``lzma.open()`` unchanged.
    """
    import lzma

    # The stream is handed back to the caller, so no context manager here.
    return _maybe_wrap_buffered(
        lzma.open(filename=file_obj, mode=mode, **kwargs),  # noqa: SIM115
        mode,
    )
124
125
def _handle_lz4(file_obj: IO[bytes], mode: str, **kwargs: Any) -> IO[Any]:
    """Open *file_obj* through ``lz4.frame.open()`` and buffer it when in binary mode.

    Any extra keyword arguments are forwarded to ``lz4.frame.open()`` unchanged.
    """
    import lz4.frame

    return _maybe_wrap_buffered(lz4.frame.open(file_obj, mode=mode, **kwargs), mode)
131
132
def compression_wrapper(
    file_obj: IO[Any],
    mode: str,
    compression: str = INFER_FROM_EXTENSION,
    filename: str | None = None,
    compression_kwargs: CompressionKwargs | None = None,
) -> IO[Any]:
    """Wrap `file_obj` with an appropriate [de]compression mechanism based on its file extension.

    If the filename extension isn't recognized, simply return the original `file_obj` unchanged.

    `file_obj` must either be a filehandle object, or a class which behaves like one.

    If `filename` is specified, it will be used to extract the extension.
    If not, the `file_obj.name` attribute is used as the filename. When neither
    yields a usable string name (e.g. no `.name` attribute, or a bytes name),
    a warning is logged and `file_obj` is returned unchanged.

    If `compression_kwargs` is specified, its contents are forwarded as keyword
    arguments to the registered compressor callback.

    Raises:
        ValueError: If a registered compressor would be applied with an update
            mode, i.e. any mode containing ``"+"`` (``"rb+"`` or ``"r+b"``).
    """
    if compression == NO_COMPRESSION:
        return file_obj
    if compression == INFER_FROM_EXTENSION:
        try:
            # Path() raises TypeError for non-str names (e.g. bytes names from
            # files opened via a bytes path), so it belongs inside the try too.
            compression = Path((filename or file_obj.name).lower()).suffix
        except (AttributeError, TypeError):
            logger.warning(
                "unable to transparently decompress %r because it seems to lack a string-like .name", file_obj
            )
            return file_obj

    try:
        callback = _COMPRESSOR_REGISTRY[compression]
    except KeyError:
        # No compressor for this extension: pass the stream through as-is.
        return file_obj

    # "+" may appear anywhere in a valid mode string ("rb+" == "r+b"), so test
    # for membership rather than only checking the last character.
    if "+" in mode:
        msg = f"transparent (de)compression unsupported for mode {mode!r}"
        raise ValueError(msg)

    return callback(file_obj, mode, **(compression_kwargs or {}))
174
175
176#
177# NB. avoid using lambda here to make stack traces more readable.
178#
179register_compressor(".bz2", _handle_bz2)
180register_compressor(".gz", _handle_gzip)
181register_compressor(".zst", _handle_zstd)
182register_compressor(".xz", _handle_xz)
183register_compressor(".lz4", _handle_lz4)