Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/bitarray/util.py: 16%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# Copyright (c) 2019 - 2025, Ilan Schnell; All Rights Reserved
2# bitarray is published under the PSF license.
3#
4# Author: Ilan Schnell
5"""
6Useful utilities for working with bitarrays.
7"""
8import os
9import sys
10import math
11import random
13from bitarray import bitarray, bits2bytes
15from bitarray._util import (
16 zeros, ones, count_n, parity, _ssqi, xor_indices,
17 count_and, count_or, count_xor, any_and, subset,
18 correspond_all, byteswap,
19 serialize, deserialize,
20 ba2hex, hex2ba,
21 ba2base, base2ba,
22 sc_encode, sc_decode,
23 vl_encode, vl_decode,
24 canonical_decode,
25)
# public names of this module (order kept as-is)
__all__ = [
    'zeros',
    'ones',
    'urandom',
    'random_k',
    'random_p',
    'gen_primes',
    'pprint',
    'strip',
    'count_n',
    'parity',
    'sum_indices',
    'xor_indices',
    'count_and',
    'count_or',
    'count_xor',
    'any_and',
    'subset',
    'correspond_all',
    'byteswap',
    'intervals',
    'ba2hex',
    'hex2ba',
    'ba2base',
    'base2ba',
    'ba2int',
    'int2ba',
    'serialize',
    'deserialize',
    'sc_encode',
    'sc_decode',
    'vl_encode',
    'vl_decode',
    'huffman_code',
    'canonical_huffman',
    'canonical_decode',
]
def urandom(__length, endian=None):
    """urandom(n, /, endian=None) -> bitarray

Create a bitarray of length `n` whose content is taken from `os.urandom()`.
"""
    nbytes = bits2bytes(__length)
    result = bitarray(os.urandom(nbytes), endian)
    # the buffer was filled from whole bytes - cut it back to `n` bits
    del result[__length:]
    return result
def random_k(__n, k, endian=None):
    """random_k(n, /, k, endian=None) -> bitarray

Create a (pseudo-) random bitarray of length `n` in which exactly `k`
elements are one. Mathematically, this is the same as taking a bitarray
of length `n` and setting the bits at `random.sample(range(n), k)` to one.
Seeding Python's `random.seed()` with a specific value makes the resulting
bitarrays reproducible.
Raises `TypeError` when `k` is not an integer.
"""
    generator = _Random(__n, endian)
    if isinstance(k, int):
        return generator.random_k(k)

    raise TypeError("int expected, got '%s'" % type(k).__name__)
def random_p(__n, p=0.5, endian=None):
    """random_p(n, /, p=0.5, endian=None) -> bitarray

Create a (pseudo-) random bitarray of length `n` in which every bit is one
with probability `p`, independently of all other bits. Mathematically, this
is the same as `bitarray((random() < p for _ in range(n)), endian)`, only
much faster when `n` is large. Seeding Python's `random.seed()` with a
specific value makes the resulting bitarrays reproducible.

Python 3.12 or higher is required, because the standard library function
`random.binomialvariate()` is needed. On older Python versions,
`NotImplementedError` is raised.
"""
    if sys.version_info[:2] >= (3, 12):
        return _Random(__n, endian).random_p(p)

    raise NotImplementedError("bitarray.util.random_p() requires "
                              "Python 3.12 or higher")
class _Random:

    # This class mainly exists so that the individual steps can be tested
    # separately by the test class Random_P_Tests in 'test_util.py'.
    # That test class also holds many comments and explanations.
    # For a description of how the algorithm works, see ./doc/random_p.rst
    # See also, VerificationTests in devel/test_random.py

    # upper limit on the number of .random_half() calls in .combine_half()
    M = 8

    # number of resulting probability intervals
    K = 1 << M

    # below this probability, .random_p() sets individual bits randomly
    SMALL_P = 0.01

    def __init__(self, n=0, endian=None):
        # length (in bits) and bit-endianness of the bitarrays to create,
        # as well as the byte count reported by bits2bytes() for n bits
        self.n, self.endian = n, endian
        self.nbytes = bits2bytes(n)

    def random_half(self):
        """
        Create a bitarray of length n in which each bit is 1 with
        probability p = 1/2.
        """
        # the random module (and not urandom()) is used for reproducibility
        raw = random.getrandbits(8 * self.nbytes)
        bits = bitarray(raw.to_bytes(self.nbytes, 'little'), self.endian)
        del bits[self.n:]
        return bits

    def op_seq(self, i):
        """
        Create the operator sequence for index i as a bitarray.
        Each item stands for one bitwise operation:   0: AND   1: OR
        Applying the sequence (see .combine_half()) gives a bitarray
        with probability q = i / K.
        Raises ValueError unless 0 < i < K.
        """
        if not 0 < i < self.K:
            raise ValueError("0 < i < %d, got i = %d" % (self.K, i))

        # little-endian bits of i: least significant operations come first
        ops = bitarray(i.to_bytes(2, byteorder="little"), "little")
        # skip everything up to (and including) the lowest bit set in i
        first = ops.index(1) + 1
        return ops[first:self.M]

    def combine_half(self, seq):
        """
        Start with one random bitarray of probability 1/2, and merge one
        further such bitarray into it for every item of the given operator
        sequence (1: bitwise OR, 0: bitwise AND).
        """
        acc = self.random_half()
        for use_or in seq:
            nxt = self.random_half()
            if use_or:
                acc |= nxt
            else:
                acc &= nxt
        return acc

    def random_k(self, k):
        """
        Create a random bitarray of length n with exactly k bits set to 1.
        Raises ValueError unless 0 <= k <= n.
        """
        n, endian = self.n, self.endian

        # validate k, and deal with the two trivial cases
        if k <= 0 or k >= n:
            if k == 0:
                return zeros(n, endian)
            if k == n:
                return ones(n, endian)
            raise ValueError("k must be in range 0 <= k <= n, got %s" % k)

        # by symmetry, only k <= n // 2 needs to be handled below
        if k > n // 2:
            result = self.random_k(n - k)
            result.invert()  # in-place, so no copy is made
            return result

        # pick the sequence index, see VerificationTests devel/test_random.py
        seq_index = 0
        if not (k < 16 or k * self.K < 3 * n):
            ratio = k / n  # ratio <= 0.5
            ratio -= (0.2 - 0.4 * ratio) / math.sqrt(n)
            seq_index = int(ratio * (self.K + 1))

        # starting point: all zeros, or a combination of random bitarrays
        # (bitwise AND and OR operations)
        if seq_index < 3:
            result = zeros(n, endian)
            excess = -k
        else:
            result = self.combine_half(self.op_seq(seq_index))
            excess = result.count() - k

        if excess:
            # excess < 0: too few bits are 1 - flip random 0 bits to 1
            # excess > 0: too many bits are 1 - flip random 1 bits to 0
            target = 1 if excess < 0 else 0
            randrange = random.randrange
            for _ in range(abs(excess)):
                pos = randrange(n)
                while result[pos] == target:
                    pos = randrange(n)
                result[pos] = target

        return result

    def random_p(self, p):
        """
        Create a random bitarray of length n in which each bit is 1 with
        probability p.  Raises ValueError unless 0.0 <= p <= 1.0.
        """
        n, endian = self.n, self.endian

        # validate p, and deal with the three special cases
        if p <= 0.0 or p == 0.5 or p >= 1.0:
            if p == 0.0:
                return zeros(n, endian)
            if p == 0.5:
                return self.random_half()
            if p == 1.0:
                return ones(n, endian)
            raise ValueError("p must be in range 0.0 <= p <= 1.0, got %s" % p)

        def literal():
            # literal definition: one random.random() call for each bit
            return bitarray((random.random() < p for _ in range(n)), endian)

        # small n
        if n < 16:
            return literal()

        # by symmetry, only p < 0.5 needs to be handled below
        if p > 0.5:
            result = self.random_p(1.0 - p)
            result.invert()  # in-place, so no copy is made
            return result

        # small p: draw the number of 1 bits, and set them individually
        if p < self.SMALL_P:
            return self.random_k(random.binomialvariate(n, p))

        # determine the operator sequence
        idx = int(p * self.K)
        if p * (self.K + 1) > idx + 1:  # see devel/test_random.py
            idx += 1
        ops = self.op_seq(idx)
        q = idx / self.K

        # n is small compared to the number of operations: literal as well
        extra = 3 if q != p else 0
        if n < 100 and self.nbytes <= len(ops) + extra:
            return literal()

        # combine random bitarrays using bitwise AND and OR operations,
        # then adjust the probability from q to p when they differ
        result = self.combine_half(ops)
        if q < p:
            result |= self.random_p((p - q) / (1.0 - q))
        elif q > p:
            result &= self.random_p(p / q)

        return result
def gen_primes(__n, endian=None, odd=False):
    """gen_primes(n, /, endian=None, odd=False) -> bitarray

Create a bitarray of length `n` whose active indices mark prime numbers.
With the default `odd=False`, index `i` is active when `i` itself is prime.
With `odd=True`, only odd numbers are represented in the resulting
bitarray `a`, and `a[i]` tells whether `2*i+1` is prime.
Raises `ValueError` when `n` is negative.
"""
    n = int(__n)
    if n < 0:
        raise ValueError("bitarray length must be >= 0")

    # Build one period of the pattern left over after removing the multiples
    # of the low primes.  `head` holds the correct first 8 bits, which the
    # periodic pattern gets wrong (it clears the low primes themselves).
    if odd:
        period, head = 105, "01110110"  # 105 = 3 * 5 * 7
        sieve = ones(period, endian)
        for q in 3, 5, 7:
            # the odd number q sits at index q // 2
            sieve[q // 2::q] = 0
    else:
        period, head = 210, "00110101"  # 210 = 2 * 3 * 5 * 7
        sieve = ones(period, endian)
        for q in 2, 3, 5, 7:
            sieve[::q] = 0

    # repeating the pattern many times is faster than setting the multiples
    # of the low primes to 0
    sieve *= (n + period - 1) // period
    sieve[:8] = bitarray(head, endian)
    del sieve[n:]

    # the actual sieve starts at 11
    if odd:
        limit = int(math.sqrt(n // 2) + 1.0)
        for i in sieve.search(1, 5, limit):  # 11//2 = 5
            p = 2 * i + 1
            sieve[(p * p) // 2::p] = 0
    else:
        limit = int(math.sqrt(n) + 1.0)
        # p*p is always odd, and the even bits are 0 already: use step 2*p
        for p in sieve.search(1, 11, limit):
            sieve[p * p::2 * p] = 0

    return sieve
def sum_indices(__a, mode=1):
    """sum_indices(a, /, mode=1) -> int

Add up the indices of all active bits in bitarray `a`, that is
`sum(i for i, v in enumerate(a) if v)`.
With `mode=2`, the squares of the indices are added up instead.
Raises `ValueError` when `mode` is neither 1 nor 2.
"""
    if mode not in (1, 2):
        raise ValueError("unexpected mode %r" % mode)

    # For details see: devel/test_sum_indices.py
    block_bits = 1 << 19  # block size 512 Kbits
    nbits = len(__a)
    if nbits <= block_bits:  # a single block - no splitting required
        return _ssqi(__a, mode)

    # Constants
    block_bytes = block_bits // 8
    # sum of 0, 1, ..., block_bits - 1 (a block with all bits set)
    full_sum = block_bits * (block_bits - 1) // 2
    # sum of the squares of 0, 1, ..., block_bits - 1
    full_sqr = full_sum * (2 * block_bits - 1) // 3

    last = (nbits + block_bits - 1) // block_bits - 1  # index of last block
    padbits = __a.padbits
    total = 0
    for i in range(last + 1):
        # a memoryview slice avoids copying memory
        view = memoryview(__a)[i * block_bytes:(i + 1) * block_bytes]
        block = bitarray(None, __a.endian, buffer=view)
        if padbits and i == last:
            # clear the pad bits at the end of the last block, so they are
            # never counted (a read-only block has to be copied first)
            if block.readonly:
                block = bitarray(block)
            block[-padbits:] = 0

        cnt = block.count()
        if not cnt:
            continue

        offset = block_bits * i  # index of the first bit in this block
        s1 = full_sum if cnt == block_bits else _ssqi(block)
        if mode == 1:
            # sum of (offset + j)
            total += cnt * offset + s1
        else:
            s2 = full_sqr if cnt == block_bits else _ssqi(block, 2)
            # sum of (offset + j) ** 2
            total += (cnt * offset + 2 * s1) * offset + s2

    return total
def pprint(__a, stream=None, group=8, indent=4, width=80):
    """pprint(bitarray, /, stream=None, group=8, indent=4, width=80)

Pretty-print a bitarray object to `stream`, which defaults to `sys.stdout`.
With the default arguments, bits are shown in groups of 8 (bytes), with
64 bits on each line.
Objects which are not bitarrays are passed on to `pprint.pprint()`.
Raises `ValueError` when `group < 1`, `indent < 0` or `width <= indent`.
"""
    if stream is None:
        stream = sys.stdout

    if not isinstance(__a, bitarray):
        import pprint as _pprint
        _pprint.pprint(__a, stream=stream, indent=indent, width=width)
        return

    group = int(group)
    if group < 1:
        raise ValueError('group must be >= 1')
    indent = int(indent)
    if indent < 0:
        raise ValueError('indent must be >= 0')
    width = int(width)
    if width <= indent:
        raise ValueError('width must be > %d (indent)' % indent)

    # elements per line: as many whole groups (each followed by one space)
    # as fit into the space right of the indentation
    epl = group * ((width - indent) // (group + 1))
    if epl == 0:
        # not even a single group fits
        epl = width - indent - 2

    type_name = type(__a).__name__
    nbits = len(__a)
    # here 4 is len("'()'")
    multiline = len(type_name) + 4 + nbits + nbits // group >= width
    quotes = "'''" if multiline else ("'" if __a else "")

    write = stream.write
    margin = indent * ' '
    write("%s(%s" % (type_name, quotes))
    for pos, bit in enumerate(__a):
        line_start = pos % epl == 0
        if multiline and line_start:
            write('\n%s' % margin)
        if pos % group == 0 and not line_start:
            write(' ')
        write(str(bit))

    if multiline:
        write('\n')

    write("%s)\n" % quotes)
    stream.flush()
def strip(__a, mode='right'):
    """strip(bitarray, /, mode='right') -> bitarray

Return a new bitarray with zeros stripped from left, right or both ends.
Allowed values for mode are the strings: `left`, `right`, `both`
Raises `TypeError` when `mode` is not a string, and `ValueError` when
`mode` is a string other than the allowed values.
"""
    if not isinstance(mode, str):
        # report the type of the offending argument `mode` itself
        # (not the type of the bitarray)
        raise TypeError("str expected for mode, got '%s'" %
                        type(mode).__name__)
    if mode not in ('left', 'right', 'both'):
        raise ValueError("mode must be 'left', 'right' or 'both', got %r" %
                         mode)

    # index of the first 1 bit (not needed when only stripping on the right)
    start = None if mode == 'right' else __a.find(1)
    if start == -1:
        # no 1 bit found at all: everything is stripped
        return __a[:0]
    # one past the last 1 bit (not needed when only stripping on the left)
    stop = None if mode == 'left' else __a.find(1, right=1) + 1
    return __a[start:stop]
def intervals(__a):
    """intervals(bitarray, /) -> iterator

Split the bitarray into its uninterrupted runs of equal bits, and return
an iterator over tuples `(value, start, stop)`, one for each run. The
tuples are produced in order, and no run is empty (`stop - start > 0`).
Nothing is produced for an empty bitarray.
"""
    try:
        current = __a[0]  # value of the run being measured
    except IndexError:
        return

    total = len(__a)
    begin = 0  # start of the current run (the stop of the previous one)
    while begin < total:
        # assert __a[begin] == current
        try:
            # the run ends where the opposite value occurs next
            end = __a.index(not current, begin)
        except ValueError:
            # no opposite value left: the run extends to the very end
            end = total
        yield int(current), begin, end
        begin = end
        current = not current  # the following run has the opposite value
def ba2int(__a, signed=False):
    """ba2int(bitarray, /, signed=False) -> int

Return the integer represented by the given bitarray, respecting the
bit-endianness of the bitarray.
With `signed=True`, the bits are interpreted as two's complement.
Raises `TypeError` when the argument is not a bitarray, and `ValueError`
when the bitarray is empty.
"""
    if not isinstance(__a, bitarray):
        raise TypeError("bitarray expected, got '%s'" % type(__a).__name__)
    nbits = len(__a)
    if not nbits:
        raise ValueError("non-empty bitarray expected")

    endian = __a.endian
    padded = __a
    if __a.padbits:
        # extend to a whole number of bytes by adding zeros on the most
        # significant side: end (little-endian) or beginning (big-endian)
        fill = zeros(__a.padbits, endian)
        padded = __a + fill if endian == "little" else fill + __a

    value = int.from_bytes(padded.tobytes(), byteorder=endian)

    # most significant bit set: the two's complement value is negative
    if signed and value >> (nbits - 1):
        value -= 1 << nbits
    return value
def int2ba(__i, length=None, endian=None, signed=False):
    """int2ba(int, /, length=None, endian=None, signed=False) -> bitarray

Return the given integer as a bitarray with given bit-endianness. Unless
`length` is provided, the result has no leading (big-endian) / trailing
(little-endian) zeros. When the integer cannot be represented with the
given number of bits, `OverflowError` is raised.
With `signed=True`, two's complement is used to represent the integer,
which requires `length` to be provided.
"""
    if not isinstance(__i, int):
        raise TypeError("int expected, got '%s'" % type(__i).__name__)

    has_length = length is not None
    if has_length:
        if not isinstance(length, int):
            raise TypeError("int expected for argument 'length'")
        if length <= 0:
            raise ValueError("length must be > 0")

    value = __i
    if signed:
        if not has_length:
            raise TypeError("signed requires argument 'length'")
        half = 1 << (length - 1)
        if value < -half or value >= half:
            raise OverflowError("signed integer not in range(%d, %d), "
                                "got %d" % (-half, half, value))
        if value < 0:
            # two's complement representation of a negative number
            value += 1 << length
    elif has_length and value >> length:
        raise OverflowError("unsigned integer not in range(0, %d), "
                            "got %d" % (1 << length, value))

    bits = bitarray(0, endian)
    nbytes = bits2bytes(value.bit_length())
    bits.frombytes(value.to_bytes(nbytes, byteorder=bits.endian))
    little = bits.endian == 'little'

    if not has_length:
        if not bits:
            # the integer 0 is represented by a single 0 bit
            return bits + '0'
        # remove the zeros on the most significant side
        return strip(bits, 'right' if little else 'left')

    size = len(bits)
    if size == length:
        return bits
    if size > length:
        # whole bytes were added: drop the surplus most significant bits
        return bits[:length] if little else bits[-length:]
    # size < length: add zeros on the most significant side
    padding = zeros(length - size, bits.endian)
    return bits + padding if little else padding + bits
503# ------------------------------ Huffman coding -----------------------------
def _huffman_tree(__freq_map):
    """_huffman_tree(dict, /) -> Node

Build the Huffman tree for a dict which maps symbols to their frequency,
and return the root node of the tree.
"""
    import heapq

    class Node(object):
        """
        There are two types of Node instances (both have a 'freq' attribute):
          * leaf node: has a 'symbol' attribute
          * parent node: has a 'child' attribute (tuple with both children)
        """
        def __lt__(self, other):
            # nodes are ordered by frequency - heapq needs to compare them
            return self.freq < other.freq

    def make_leaf(symbol, freq):
        node = Node()
        node.symbol = symbol
        node.freq = freq
        return node

    def make_parent(first, second):
        node = Node()
        node.child = first, second
        node.freq = first.freq + second.freq
        return node

    # the queue initially holds one leaf node for each symbol
    queue = []
    for symbol, freq in __freq_map.items():
        heapq.heappush(queue, make_leaf(symbol, freq))

    # Merge the two nodes of lowest frequency into a new parent node, which
    # goes back onto the queue, until only a single node is left.
    while len(queue) > 1:
        lowest = heapq.heappop(queue)
        second = heapq.heappop(queue)
        heapq.heappush(queue, make_parent(lowest, second))

    # this last node is the root of the Huffman tree
    return queue[0]
def huffman_code(__freq_map, endian=None):
    """huffman_code(dict, /, endian=None) -> dict

Calculate the Huffman code for a frequency map (a dictionary mapping symbols
to their frequency). The result is a dict which maps the same symbols to
bitarrays (with given bit-endianness). Symbols do not have to be strings,
any hashable object may be used as a symbol.
Raises `TypeError` when the argument is not a dict, and `ValueError` when
the dict is empty.
"""
    if not isinstance(__freq_map, dict):
        raise TypeError("dict expected, got '%s'" % type(__freq_map).__name__)

    nsymbols = len(__freq_map)
    if nsymbols == 0:
        raise ValueError("cannot create Huffman code with no symbols")
    if nsymbols == 1:
        # A single symbol could normally be represented with zero bits.
        # However, the .encode() and .decode() methods need a code of at
        # least one bit to work.  Therefore, the symbol is represented by a
        # code of length one, namely a single 0 bit.  This code is
        # incomplete: a received 1 bit has no meaning and results in an
        # error.
        only = next(iter(__freq_map))
        return {only: bitarray('0', endian)}

    codes = {}

    def assign(node, code):
        if hasattr(node, 'symbol'):  # leaf
            codes[node.symbol] = code
            return
        # parent: first child extends the code by 0, second child by 1
        zero_branch, one_branch = node.child
        assign(zero_branch, code + '0')
        assign(one_branch, code + '1')

    assign(_huffman_tree(__freq_map), bitarray(0, endian))
    return codes
def canonical_huffman(__freq_map):
    """canonical_huffman(dict, /) -> tuple

Calculate the canonical Huffman code for a frequency map (a dictionary
mapping symbols to their frequency). The result is a tuple containing:

0. the canonical Huffman code as a dict mapping symbols to bitarrays
1. a list containing the number of symbols of each code length
2. a list of symbols in canonical order

Note: the two lists may be used as input for `canonical_decode()`.
Raises `TypeError` when the argument is not a dict, and `ValueError` when
the dict is empty.
"""
    if not isinstance(__freq_map, dict):
        raise TypeError("dict expected, got '%s'" % type(__freq_map).__name__)

    nsymbols = len(__freq_map)
    if nsymbols == 0:
        raise ValueError("cannot create Huffman code with no symbols")
    if nsymbols == 1:
        # a single symbol is represented by a single 0 bit
        only = next(iter(__freq_map))
        return {only: bitarray('0', 'big')}, [0, 1], [only]

    lengths = {}  # maps each symbol to the length of its code

    def measure(node, depth):
        # walk the Huffman tree, recording only how deep each symbol sits
        if hasattr(node, 'symbol'):  # leaf
            lengths[node.symbol] = depth
            return
        for branch in node.child:  # parent
            measure(branch, depth + 1)

    measure(_huffman_tree(__freq_map), 0)

    # The code lengths are all that is required.  Arrange the (symbol, code
    # length) pairs by increasing code length:
    ordered = sorted(lengths.items(), key=lambda pair: pair[1])

    counts = [0] * (ordered[-1][1] + 1)
    codes = {}
    code = 0
    previous = ordered[0][1]
    for sym, nbits in ordered:
        # widen the running code value when the code length increases
        code <<= nbits - previous
        codes[sym] = int2ba(code, nbits, 'big')
        counts[nbits] += 1
        code += 1
        previous = nbits

    return codes, counts, [sym for sym, _ in ordered]