Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/smart_open/compression.py: 39%
56 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-06-07 06:57 +0000
1# -*- coding: utf-8 -*-
2#
3# Copyright (C) 2020 Radim Rehurek <me@radimrehurek.com>
4#
5# This code is distributed under the terms and conditions
6# from the MIT License (MIT).
7#
8"""Implements the compression layer of the ``smart_open`` library."""
9import logging
10import os.path
logger = logging.getLogger(name=__name__)

# Maps a lower-cased file extension (e.g. ``.gz``) to the callback that wraps
# a file object in the matching (de)compressor; filled by register_compressor().
_COMPRESSOR_REGISTRY = dict()

NO_COMPRESSION = 'disable'
"""Use no compression. Read/write the data as-is."""
INFER_FROM_EXTENSION = 'infer_from_extension'
"""Determine the compression to use from the file extension.

See get_supported_extensions().
"""
def get_supported_compression_types():
    """List every value accepted as the ``compression`` argument.

    The two special modes, ``NO_COMPRESSION`` and ``INFER_FROM_EXTENSION``,
    come first, followed by each registered extension in sorted order.
    See the compression parameter of smart_open.open().
    """
    accepted = [NO_COMPRESSION, INFER_FROM_EXTENSION]
    accepted.extend(get_supported_extensions())
    return accepted
def get_supported_extensions():
    """Return, in sorted order, the extensions that have a registered compressor."""
    # Iterating a dict yields its keys, so this is the sorted key list.
    return sorted(_COMPRESSOR_REGISTRY)
def register_compressor(ext, callback):
    """Register a callback for transparently decompressing files with a specific extension.

    Parameters
    ----------
    ext: str
        The extension. Must include the leading period, e.g. ``.gz``.
        The extension is stored lower-cased, replacing (with a warning) any
        handler previously registered for it.
    callback: callable
        The callback. It must accept two positional arguments, file_obj and mode.
        This function will be called when ``smart_open`` is opening a file with
        the specified extension.

    Raises
    ------
    ValueError
        If ``ext`` is not a string starting with a period.

    Examples
    --------

    Instruct smart_open to use the `lzma` module whenever opening a file
    with a .xz extension (see README.rst for the complete example showing I/O):

    >>> def _handle_xz(file_obj, mode):
    ...     import lzma
    ...     return lzma.LZMAFile(filename=file_obj, mode=mode, format=lzma.FORMAT_XZ)
    >>>
    >>> register_compressor('.xz', _handle_xz)

    """
    # Require an actual string: a non-string whose first element happens to be
    # '.' (e.g. a list) would otherwise pass and then fail on .lower().
    if not (isinstance(ext, str) and ext.startswith('.')):
        # Format via a 1-tuple so that a tuple ``ext`` is shown with %r rather
        # than being unpacked as the format arguments (which raises TypeError).
        raise ValueError('ext must be a string starting with ., not %r' % (ext,))
    ext = ext.lower()
    if ext in _COMPRESSOR_REGISTRY:
        logger.warning('overriding existing compression handler for %r', ext)
    _COMPRESSOR_REGISTRY[ext] = callback
def tweak_close(outer, inner):
    """Ensure that closing the `outer` stream closes the `inner` stream as well.

    Use this when your compression library's `close` method does not
    automatically close the underlying filestream. See
    https://github.com/RaRe-Technologies/smart_open/issues/630 for an
    explanation why that is a problem for smart_open.

    ``outer.close`` is replaced in place. Every call to the new ``close``
    invokes the original ``outer.close``; the inner stream is closed at most
    once, even when the original ``outer.close`` raises.

    Parameters
    ----------
    outer:
        The wrapping stream; must have a ``close`` method.
    inner:
        The underlying stream wrapped by ``outer``; must have a ``close`` method.
    """
    outer_close = outer.close

    def close_both(*args):
        # Extra positional arguments are accepted and ignored.
        nonlocal inner
        try:
            outer_close()
        finally:
            # Test against None instead of truthiness: a stream object that
            # defines __len__ or __bool__ may be falsy and would then never
            # be closed.
            if inner is not None:
                # Drop our reference before closing so repeated close() calls
                # do not close the inner stream a second time.
                inner, fp = None, inner
                fp.close()

    outer.close = close_both
def _handle_bz2(file_obj, mode):
    """Wrap ``file_obj`` in a bz2 (de)compressor opened with ``mode``.

    tweak_close() is applied so that closing the returned wrapper also
    closes ``file_obj``.
    """
    import bz2

    wrapper = bz2.BZ2File(file_obj, mode=mode)
    tweak_close(wrapper, file_obj)
    return wrapper
def _handle_gzip(file_obj, mode):
    """Wrap ``file_obj`` in a gzip (de)compressor opened with ``mode``.

    tweak_close() is applied so that closing the returned wrapper also
    closes ``file_obj``.
    """
    from gzip import GzipFile

    wrapper = GzipFile(mode=mode, fileobj=file_obj)
    tweak_close(wrapper, file_obj)
    return wrapper
def compression_wrapper(file_obj, mode, compression=INFER_FROM_EXTENSION, filename=None):
    """
    Wrap `file_obj` with an appropriate [de]compression mechanism based on its file extension.

    If the filename extension isn't recognized, simply return the original `file_obj` unchanged.

    `file_obj` must either be a filehandle object, or a class which behaves like one.

    If `filename` is specified, it will be used to extract the extension.
    If not, the `file_obj.name` attribute is used as the filename.
    Either may be a string or a path-like object (e.g. ``pathlib.Path``).

    Parameters
    ----------
    file_obj:
        The stream to wrap.
    mode: str
        The mode the stream was opened with; passed through to the compressor callback.
    compression: str
        ``NO_COMPRESSION``, ``INFER_FROM_EXTENSION``, or a registered extension
        such as ``'.gz'``.
    filename: str or os.PathLike, optional
        Used instead of ``file_obj.name`` when inferring the extension.

    Returns
    -------
    The object returned by the registered compressor callback, or ``file_obj``
    itself when no compressor applies.

    Raises
    ------
    ValueError
        If a registered compressor applies and ``mode`` is an update mode
        (contains ``'+'``).
    """
    if compression == NO_COMPRESSION:
        return file_obj
    elif compression == INFER_FROM_EXTENSION:
        try:
            # os.fspath() accepts str, bytes and os.PathLike; a non-path name
            # such as an integer file descriptor raises TypeError.
            filename = os.fspath(filename or file_obj.name).lower()
        except (AttributeError, TypeError):
            logger.warning(
                'unable to transparently decompress %r because it '
                'seems to lack a string-like .name', file_obj
            )
            return file_obj
        _, compression = os.path.splitext(filename)

    # '+' may appear anywhere in a valid mode string ('r+b' as well as 'rb+'),
    # so test for membership rather than only the last character.
    if compression in _COMPRESSOR_REGISTRY and '+' in mode:
        raise ValueError('transparent (de)compression unsupported for mode %r' % mode)

    try:
        callback = _COMPRESSOR_REGISTRY[compression]
    except KeyError:
        return file_obj
    else:
        return callback(file_obj, mode)
# Built-in handlers. They are registered as named module-level functions
# rather than lambdas so that stack traces stay readable.
register_compressor('.bz2', callback=_handle_bz2)
register_compressor('.gz', callback=_handle_gzip)