1# -*- coding: utf-8 -*-
2#
3# Copyright (C) 2020 Radim Rehurek <me@radimrehurek.com>
4#
5# This code is distributed under the terms and conditions
6# from the MIT License (MIT).
7#
8"""Implements the compression layer of the `smart_open` library."""
9import io
10import logging
11import os.path
12
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Maps a lowercased file extension (including the leading period, e.g. '.gz')
# to the callback registered for it via register_compressor().
_COMPRESSOR_REGISTRY = {}

# The two special (non-extension) values accepted as a compression type.
NO_COMPRESSION = 'disable'
"""Use no compression. Read/write the data as-is."""
INFER_FROM_EXTENSION = 'infer_from_extension'
"""Determine the compression to use from the file extension.

See get_supported_extensions().
"""
24
25
def get_supported_compression_types():
    """List every value accepted as a compression type.

    The result starts with the two special values (NO_COMPRESSION and
    INFER_FROM_EXTENSION), followed by the registered file extensions
    in sorted order.

    See the compression parameter of smart_open.open().
    """
    supported = [NO_COMPRESSION, INFER_FROM_EXTENSION]
    supported.extend(get_supported_extensions())
    return supported
32
33
def get_supported_extensions():
    """Return a sorted list of the file extensions that have a registered compressor."""
    extensions = list(_COMPRESSOR_REGISTRY)
    extensions.sort()
    return extensions
37
38
def register_compressor(ext, callback):
    """Register a callback for transparently (de)compressing files with a specific extension.

    Registering an extension that already has a handler replaces the old
    handler and logs a warning.

    Parameters
    ----------
    ext: str
        The extension. Must include the leading period, e.g. `.gz`.
        It is lowercased before being stored.
    callback: callable
        The callback. It must accept two positional arguments, file_obj and mode.
        This function will be called when `smart_open` is opening a file with
        the specified extension.

    Raises
    ------
    ValueError
        If `ext` is not a string, is empty, or does not start with a period.

    Examples
    --------

    Instruct smart_open to use the `lzma` module whenever opening a file
    with a .xz extension (see README.rst for the complete example showing I/O):

    >>> def _handle_xz(file_obj, mode):
    ...     import lzma
    ...     return lzma.LZMAFile(filename=file_obj, mode=mode)
    >>>
    >>> register_compressor('.xz', _handle_xz)

    This is just an example: `lzma` is in the standard library and is registered by default.

    """
    # Check the type explicitly: otherwise a non-string `ext` (an int, a list, ...)
    # escapes as a TypeError/AttributeError from the indexing or `.lower()` call
    # instead of the ValueError announced by the message below.
    if not (isinstance(ext, str) and ext.startswith('.')):
        # Wrap in a 1-tuple so that a tuple-valued `ext` is formatted, not unpacked.
        raise ValueError('ext must be a string starting with ., not %r' % (ext,))
    ext = ext.lower()
    if ext in _COMPRESSOR_REGISTRY:
        logger.warning('overriding existing compression handler for %r', ext)
    _COMPRESSOR_REGISTRY[ext] = callback
72
73
def tweak_close(outer, inner):
    """Ensure that closing the `outer` stream closes the `inner` stream as well.

    Deprecated: `smart_open.open().__exit__` now always calls `__exit__` on the
    underlying filestream.

    Use this when your compression library's `close` method does not
    automatically close the underlying filestream. See
    https://github.com/piskvorky/smart_open/issues/630 for an
    explanation why that is a problem for smart_open.

    The function replaces `outer.close` in place and returns nothing.
    The replacement calls the original `outer.close()` and then, even if
    that raised, closes `inner`. `inner` is closed at most once, no matter
    how many times `outer.close()` is called.
    """
    outer_close = outer.close

    def close_both(*args):
        nonlocal inner
        try:
            outer_close()
        finally:
            # Identity check rather than truthiness: a stream object may define
            # __len__/__bool__ and evaluate as false while still needing a close.
            if inner is not None:
                # Drop our reference before closing so that a repeated
                # close() call does not close the inner stream again.
                inner, fp = None, inner
                fp.close()

    outer.close = close_both
97
98
99def _maybe_wrap_buffered(file_obj, mode):
100 # https://github.com/piskvorky/smart_open/issues/760#issuecomment-1553971657
101 result = file_obj
102 if "b" in mode and "w" in mode:
103 result = io.BufferedWriter(result)
104 elif "b" in mode and "r" in mode:
105 result = io.BufferedReader(result)
106 return result
107
108
def _handle_bz2(file_obj, mode):
    """Open `file_obj` through `bz2.open` in `mode`, buffering binary reads/writes."""
    # Imported inside the function so the module is loaded only when needed.
    import bz2
    return _maybe_wrap_buffered(bz2.open(filename=file_obj, mode=mode), mode)
113
114
def _handle_gzip(file_obj, mode):
    """Open `file_obj` through `gzip.open` in `mode`, buffering binary reads/writes."""
    # Imported inside the function so the module is loaded only when needed.
    import gzip
    return _maybe_wrap_buffered(gzip.open(filename=file_obj, mode=mode), mode)
119
120
def _handle_zstd(file_obj, mode):
    """Open `file_obj` through `zstandard.open` in `mode`, buffering binary reads/writes."""
    # Imported inside the function: `zstandard` is only required once a
    # file is actually routed to this handler.
    import zstandard
    return _maybe_wrap_buffered(zstandard.open(filename=file_obj, mode=mode), mode)
125
126
def _handle_xz(file_obj, mode):
    """Open `file_obj` through `lzma.open` in `mode`, buffering binary reads/writes."""
    # Imported inside the function so the module is loaded only when needed.
    import lzma
    return _maybe_wrap_buffered(lzma.open(filename=file_obj, mode=mode), mode)
131
132
def compression_wrapper(file_obj, mode, compression=INFER_FROM_EXTENSION, filename=None):
    """
    Wrap `file_obj` with the [de]compression handler selected by `compression`.

    `file_obj` must either be a filehandle object, or a class which behaves like one.

    `compression` is one of:

    - NO_COMPRESSION: `file_obj` is returned unchanged.
    - INFER_FROM_EXTENSION: the lowercased extension of the filename is used.
      If `filename` is specified, it will be used to extract the extension.
      If not, the `file_obj.name` attribute is used as the filename.
      When no string-like name is available, a warning is logged and
      `file_obj` is returned unchanged.
    - anything else: used directly as the key to look up a registered handler.

    If no handler is registered for the resulting extension, `file_obj` is
    returned unchanged.

    Raises ValueError if a handler is registered for the extension but `mode`
    ends with '+'.

    """
    if compression == NO_COMPRESSION:
        return file_obj

    if compression == INFER_FROM_EXTENSION:
        try:
            name = (filename or file_obj.name).lower()
        except (AttributeError, TypeError):
            logger.warning(
                'unable to transparently decompress %r because it '
                'seems to lack a string-like .name', file_obj
            )
            return file_obj
        compression = os.path.splitext(name)[1]

    # Unknown extension: nothing to wrap.
    if compression not in _COMPRESSOR_REGISTRY:
        return file_obj

    # Modes ending in '+' cannot be served through a compression handler.
    if mode.endswith('+'):
        raise ValueError('transparent (de)compression unsupported for mode %r' % mode)

    handler = _COMPRESSOR_REGISTRY[compression]
    return handler(file_obj, mode)
167
168
#
# Register the built-in handlers when this module is imported.
#
# NB. the handlers are named functions rather than lambdas
# to make stack traces more readable.
#
register_compressor('.bz2', _handle_bz2)
register_compressor('.gz', _handle_gzip)
register_compressor('.zst', _handle_zstd)
register_compressor('.xz', _handle_xz)