1from __future__ import annotations
2
3# built-in
4import codecs
5import math
6from collections import Counter
7from fractions import Fraction
8from itertools import groupby, permutations
9from typing import Any, Sequence, TypeVar
10
11# app
12from .base import Base as _Base
13
14
15try:
16 # built-in
17 import lzma
18except ImportError:
19 lzma = None # type: ignore[assignment]
20
21
__all__ = [
    # distance classes
    'ArithNCD',
    'LZMANCD',
    'BZ2NCD',
    'RLENCD',
    'BWTRLENCD',
    'ZLIBNCD',
    'SqrtNCD',
    'EntropyNCD',

    # ready-to-use default instances
    'bz2_ncd',
    'lzma_ncd',
    'arith_ncd',
    'rle_ncd',
    'bwtrle_ncd',
    'zlib_ncd',
    'sqrt_ncd',
    'entropy_ncd',
]
# Generic element type used in annotations (e.g. ``SqrtNCD._compress``).
T = TypeVar('T')
30
31
class _NCDBase(_Base):
    """Normalized compression distance (NCD)

    https://articles.orsinium.dev/other/ncd/
    https://en.wikipedia.org/wiki/Normalized_compression_distance#Normalized_compression_distance

    Subclasses provide ``_compress`` (and may override ``_get_size``). For
    ``n`` input sequences the distance is::

        (C(concat) - min(C(s)) * (n - 1)) / max(C(s))

    where ``C`` is ``_get_size`` and ``C(concat)`` is the smallest size
    observed over every ordering of the concatenated inputs.
    """
    qval = 1

    def __init__(self, qval: int = 1) -> None:
        self.qval = qval

    def maximum(self, *sequences) -> int:
        """Return the maximum value reported for this metric (constant 1)."""
        return 1

    def _get_size(self, data: str) -> float:
        """Return the length of ``_compress(data)``."""
        compressed = self._compress(data)
        return len(compressed)

    def _compress(self, data: str) -> Any:
        """Compress ``data``; concrete subclasses must implement this."""
        raise NotImplementedError

    def __call__(self, *sequences) -> float:
        """Return the NCD of ``sequences`` (0 when called with no arguments)."""
        if not sequences:
            return 0
        seqs = self._get_sequences(*sequences)

        # An empty value of the input type is the ``join`` separator for
        # str/bytes and the start value of ``sum`` for other sequence types.
        blank = type(seqs[0])()
        joinable = isinstance(blank, (str, bytes))

        # Concatenation order affects compressed size; keep the smallest.
        best_concat = float('Inf')
        for ordering in permutations(seqs):
            joined = blank.join(ordering) if joinable else sum(ordering, blank)
            best_concat = min(best_concat, self._get_size(joined))  # type: ignore[arg-type]

        sizes = [self._get_size(seq) for seq in seqs]
        largest = max(sizes)
        if largest == 0:
            return 0
        overlap = min(sizes) * (len(seqs) - 1)
        return (best_concat - overlap) / largest
71
72
class _BinaryNCDBase(_NCDBase):
    """Base for NCD variants whose ``_compress`` works on bytes.

    Text arguments are UTF-8 encoded before the distance is computed.
    """

    def __init__(self) -> None:
        # Takes no ``qval``; the class-level ``qval`` from ``_NCDBase`` applies.
        pass

    def __call__(self, *sequences) -> float:
        """Return the NCD of ``sequences`` after UTF-8 encoding any ``str`` ones.

        Each argument is checked individually, so a mix of ``str`` and
        ``bytes`` arguments is accepted (previously only the first argument's
        type decided whether *all* arguments were encoded).
        """
        if not sequences:
            return 0
        sequences = tuple(
            seq.encode('utf-8') if isinstance(seq, str) else seq
            for seq in sequences
        )
        return super().__call__(*sequences)
84
85
class ArithNCD(_NCDBase):
    """Arithmetic coding

    https://github.com/gw-c/arith
    http://www.drdobbs.com/cpp/data-compression-with-arithmetic-encodin/240169251
    https://en.wikipedia.org/wiki/Arithmetic_coding

    The "compressed size" of a sequence is the number of ``base``-digits
    needed for the numerator of the fraction that encodes it.
    """

    def __init__(self, base: int = 2, terminator: str | None = None, qval: int = 1) -> None:
        # base: radix used in ``_get_size`` to measure the encoded numerator.
        # terminator: optional end-of-data symbol appended before encoding.
        self.base = base
        self.terminator = terminator
        self.qval = qval

    def _make_probs(self, *sequences) -> dict[str, tuple[Fraction, Fraction]]:
        """Build the symbol model ``{symbol: (cumulative_start, width)}``.

        Symbols are ordered by descending count (``most_common``); the
        terminator, when configured, is forced to a count of exactly 1.

        https://github.com/gw-c/arith/blob/master/arith.py
        """
        sequences = self._get_counters(*sequences)
        counts = self._sum_counters(*sequences)
        if self.terminator is not None:
            counts[self.terminator] = 1
        total_letters = sum(counts.values())

        prob_pairs = {}
        cumulative_count = 0
        for char, current_count in counts.most_common():
            prob_pairs[char] = (
                Fraction(cumulative_count, total_letters),
                Fraction(current_count, total_letters),
            )
            cumulative_count += current_count
        # Internal invariant: the per-symbol intervals exactly cover [0, 1).
        assert cumulative_count == total_letters
        return prob_pairs

    def _get_range(
        self,
        data: str,
        probs: dict[str, tuple[Fraction, Fraction]],
    ) -> tuple[Fraction, Fraction]:
        """Narrow ``[0, 1)`` symbol by symbol and return the final ``[start, end)``.

        Any terminator characters already in ``data`` are removed and a single
        terminator is appended, matching the count of 1 set in ``_make_probs``.
        """
        if self.terminator is not None:
            if self.terminator in data:
                data = data.replace(self.terminator, '')
            data += self.terminator

        start = Fraction(0, 1)
        width = Fraction(1, 1)
        for char in data:
            prob_start, prob_width = probs[char]
            start += prob_start * width
            width *= prob_width
        return start, start + width

    def _compress(self, data: str) -> Fraction:
        """Return a fraction inside the coding interval of ``data``.

        Candidate denominators 1, 2, 4, ... are tried in turn; the first
        candidate that falls in ``[start, end)`` is returned.
        """
        probs = self._make_probs(data)
        start, end = self._get_range(data=data, probs=probs)
        output_fraction = Fraction(0, 1)
        output_denominator = 1
        while not (start <= output_fraction < end):
            output_numerator = 1 + ((start.numerator * output_denominator) // start.denominator)
            output_fraction = Fraction(output_numerator, output_denominator)
            output_denominator *= 2
        return output_fraction

    def _get_size(self, data: str) -> int:
        """Return ``ceil(log_base(numerator))`` of the encoded fraction (0 for 0).

        Two-argument ``math.log`` is computed in floating point and can land
        just above an integer at exact powers (e.g. ``2**31`` in base 2),
        which would make ``ceil`` overshoot by one. For integer bases the
        float estimate is therefore corrected with exact integer powers.
        """
        numerator = self._compress(data).numerator
        if numerator == 0:
            return 0
        size = math.ceil(math.log(numerator, self.base))
        base = self.base
        if isinstance(base, int) and base > 1:
            # Smallest k >= 0 with base**k >= numerator.
            while size > 0 and base ** (size - 1) >= numerator:
                size -= 1
            while base ** size < numerator:
                size += 1
        return size
154
155
class RLENCD(_NCDBase):
    """Run-length encoding

    https://en.wikipedia.org/wiki/Run-length_encoding
    """

    def _compress(self, data: Sequence) -> str:
        """Run-length encode ``data`` into a string.

        A run of one element is kept as-is, a run of two is written out twice,
        and longer runs become ``<count><element>``.
        """
        pieces = []
        for item, run in groupby(data):
            length = sum(1 for _ in run)
            if length == 1:
                pieces.append(item)
            elif length == 2:
                pieces.append(2 * item)
            else:
                pieces.append(str(length) + item)
        return ''.join(pieces)
173
174
class BWTRLENCD(RLENCD):
    """Burrows-Wheeler transform followed by run-length encoding.

    https://en.wikipedia.org/wiki/Burrows%E2%80%93Wheeler_transform
    https://en.wikipedia.org/wiki/Run-length_encoding
    """

    def __init__(self, terminator: str = '\0') -> None:
        # End marker appended to the input when it is not already present.
        self.terminator: Any = terminator

    def _compress(self, data: Sequence) -> str:
        """BWT-transform ``data`` (last column of sorted rotations), then RLE it.

        ``str`` input is handled as a string. Any other sequence (for example
        the lists that ``_NCDBase.__call__`` builds with ``sum``) is copied
        into a new list first, so the caller's object is never extended in
        place, and the last column is built as a list; the previous version
        called ``join`` on an empty list/tuple and crashed for such input.
        """
        if isinstance(data, str):
            if not data:
                data = self.terminator
            elif self.terminator not in data:
                data += self.terminator
            rotations = sorted(data[i:] + data[:i] for i in range(len(data)))
            last_column: Sequence = ''.join(rotation[-1] for rotation in rotations)
        else:
            items = list(data)
            if self.terminator not in items:
                items.append(self.terminator)
            rotations = sorted(items[i:] + items[:i] for i in range(len(items)))
            last_column = [rotation[-1] for rotation in rotations]
        return super()._compress(last_column)
193
194
195# -- NORMAL COMPRESSORS -- #
196
197
class SqrtNCD(_NCDBase):
    """Square Root based NCD

    The "compressed size" of a sequence is the sum, over its distinct
    elements, of the square root of each element's count.
    """

    def __init__(self, qval: int = 1) -> None:
        self.qval = qval

    def _compress(self, data: Sequence[T]) -> dict[T, float]:
        """Map each distinct element of ``data`` to ``sqrt(count)``."""
        counts = Counter(data)
        return dict(zip(counts.keys(), map(math.sqrt, counts.values())))

    def _get_size(self, data: Sequence) -> float:
        """Sum of the per-element square roots from ``_compress``."""
        roots = self._compress(data)
        return sum(roots.values())
213
214
class EntropyNCD(_NCDBase):
    """Entropy based NCD

    The "compressed size" of a sequence is ``coef`` plus the entropy of its
    element frequencies, with logarithms taken in base ``base``.

    https://en.wikipedia.org/wiki/Entropy_(information_theory)
    https://en.wikipedia.org/wiki/Entropy_encoding
    """

    def __init__(self, qval: int = 1, coef: int = 1, base: int = 2) -> None:
        self.base = base
        self.coef = coef
        self.qval = qval

    def _compress(self, data: Sequence) -> float:
        """Return the entropy of ``data``'s element distribution (0.0 if empty)."""
        n_items = len(data)
        result = 0.0
        for occurrences in Counter(data).values():
            share = occurrences / n_items
            result -= share * math.log(share, self.base)
        assert result >= 0
        return result

    def _get_size(self, data: Sequence) -> float:
        """Return ``coef`` plus the entropy of ``data``."""
        entropy = self._compress(data)
        return self.coef + entropy
245
246
247# -- BINARY COMPRESSORS -- #
248
249
class BZ2NCD(_BinaryNCDBase):
    """NCD that uses bzip2 (the ``bz2_codec`` codec) as its compressor.

    https://en.wikipedia.org/wiki/Bzip2
    """

    def _compress(self, data: str | bytes) -> bytes:
        # The first 15 output bytes are dropped; presumably these are fixed
        # stream-header bytes that would add the same constant to every size
        # -- TODO confirm the exact count against the bzip2 format.
        compressed = codecs.encode(data, 'bz2_codec')
        return compressed[15:]
257
258
class LZMANCD(_BinaryNCDBase):
    """NCD that uses LZMA (``lzma.compress`` with default settings) as its compressor.

    https://en.wikipedia.org/wiki/LZMA
    """

    def _compress(self, data: bytes) -> bytes:
        # ``lzma`` is bound to None at import time when it is unavailable.
        if lzma is None:
            raise ImportError('Please, install the PylibLZMA module')
        # The first 14 output bytes are dropped; presumably fixed container
        # header bytes -- TODO confirm against the .xz format.
        compressed = lzma.compress(data)
        return compressed[14:]
268
269
class ZLIBNCD(_BinaryNCDBase):
    """NCD that uses zlib (the ``zlib_codec`` codec) as its compressor.

    https://en.wikipedia.org/wiki/Zlib
    """

    def _compress(self, data: str | bytes) -> bytes:
        # A zlib stream begins with a 2-byte header (CMF/FLG, RFC 1950);
        # dropping it keeps that constant out of every measured size.
        compressed = codecs.encode(data, 'zlib_codec')
        return compressed[2:]
277
278
# Ready-to-use instances with default settings, in ``__all__`` order.
bz2_ncd = BZ2NCD()
lzma_ncd = LZMANCD()
arith_ncd = ArithNCD()
rle_ncd = RLENCD()
bwtrle_ncd = BWTRLENCD()
zlib_ncd = ZLIBNCD()
sqrt_ncd = SqrtNCD()
entropy_ncd = EntropyNCD()