Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/bitarray/util.py: 16%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

337 statements  

# Copyright (c) 2019 - 2025, Ilan Schnell; All Rights Reserved
# bitarray is published under the PSF license.
#
# Author: Ilan Schnell
"""
Useful utilities for working with bitarrays.
"""
import os
import sys
import math
import random

from bitarray import bitarray, bits2bytes

# Names implemented in the C extension module.  Most of them are re-exported
# unchanged through __all__ below; `_ssqi` is only used internally by
# sum_indices().
from bitarray._util import (
    zeros, ones, count_n, parity, _ssqi, xor_indices,
    count_and, count_or, count_xor, any_and, subset,
    correspond_all, byteswap,
    serialize, deserialize,
    ba2hex, hex2ba,
    ba2base, base2ba,
    sc_encode, sc_decode,
    vl_encode, vl_decode,
    canonical_decode,
)

# Public API of this module (what `from bitarray.util import *` exports).
__all__ = [
    'zeros', 'ones', 'urandom', 'random_k', 'random_p', 'gen_primes',
    'pprint', 'strip', 'count_n',
    'parity', 'sum_indices', 'xor_indices',
    'count_and', 'count_or', 'count_xor', 'any_and', 'subset',
    'correspond_all', 'byteswap', 'intervals',
    'ba2hex', 'hex2ba',
    'ba2base', 'base2ba',
    'ba2int', 'int2ba',
    'serialize', 'deserialize',
    'sc_encode', 'sc_decode',
    'vl_encode', 'vl_decode',
    'huffman_code', 'canonical_huffman', 'canonical_decode',
]

41 

42 

def urandom(__length, endian=None):
    """urandom(n, /, endian=None) -> bitarray

Return random bitarray of length `n` (uses `os.urandom()`).
`endian` is passed on to the bitarray constructor.
"""
    # os.urandom() only hands out whole bytes: request enough bytes to cover
    # `n` bits, then delete the surplus bits at the end.
    a = bitarray(os.urandom(bits2bytes(__length)), endian)
    del a[__length:]
    return a

51 

52 

def random_k(__n, k, endian=None):
    """random_k(n, /, k, endian=None) -> bitarray

Return (pseudo-) random bitarray of length `n` with `k` elements
set to one.  Mathematically equivalent to setting (in a bitarray of
length `n`) all bits at indices `random.sample(range(n), k)` to one.
The random bitarrays are reproducible when giving Python's `random.seed()`
a specific seed value.

Raises `TypeError` when `k` is not an int.
"""
    r = _Random(__n, endian)
    # only the type of k is checked here - its range (0 <= k <= n) is
    # validated by _Random.random_k()
    if not isinstance(k, int):
        raise TypeError("int expected, got '%s'" % type(k).__name__)

    return r.random_k(k)

67 

68 

def random_p(__n, p=0.5, endian=None):
    """random_p(n, /, p=0.5, endian=None) -> bitarray

Return (pseudo-) random bitarray of length `n`, where each bit has
probability `p` of being one (independent of any other bits).  Mathematically
equivalent to `bitarray((random() < p for _ in range(n)), endian)`, but much
faster for large `n`.  The random bitarrays are reproducible when giving
Python's `random.seed()` with a specific seed value.

This function requires Python 3.12 or higher, as it depends on the standard
library function `random.binomialvariate()`.  Raises `NotImplementedError`
when Python version is too low.
"""
    # refuse early, before any argument is looked at, so that the outcome
    # on old Python versions does not depend on `n` or `p`
    if sys.version_info[:2] < (3, 12):
        raise NotImplementedError("bitarray.util.random_p() requires "
                                  "Python 3.12 or higher")
    r = _Random(__n, endian)
    return r.random_p(p)

87 

88 

class _Random:

    # The main reason for this class is to enable testing functionality
    # individually in the test class Random_P_Tests in 'test_util.py'.
    # The test class also contains many comments and explanations.
    # To better understand how the algorithm works, see ./doc/random_p.rst
    # See also, VerificationTests in devel/test_random.py

    # maximal number of calls to .random_half() in .combine()
    M = 8

    # number of resulting probability intervals
    K = 1 << M

    # limit for setting individual bits randomly
    SMALL_P = 0.01

    def __init__(self, n=0, endian=None):
        # n: length (in bits) of every bitarray produced by this instance
        self.n = n
        # number of bytes needed to hold n bits
        self.nbytes = bits2bytes(n)
        # bit-endianness handed to every bitarray that gets created
        self.endian = endian

    def random_half(self):
        """
        Return bitarray with each bit having probability p = 1/2 of being 1.
        """
        nbytes = self.nbytes
        # use random module function for reproducibility (not urandom())
        b = random.getrandbits(8 * nbytes).to_bytes(nbytes, 'little')
        a = bitarray(b, self.endian)
        # whole bytes were generated - drop the bits beyond length n
        del a[self.n:]
        return a

    def op_seq(self, i):
        """
        Return bitarray containing operator sequence.
        Each item represents a bitwise operation:   0: AND   1: OR
        After applying the sequence (see .combine_half()), we
        obtain a bitarray with probability  q = i / K

        Raises ValueError unless 0 < i < K.
        """
        if not 0 < i < self.K:
            raise ValueError("0 < i < %d, got i = %d" % (self.K, i))

        # sequence of &, | operations - least significant operations first
        # (i < K = 256 always fits into the two bytes used here)
        a = bitarray(i.to_bytes(2, byteorder="little"), "little")
        # skip everything up to and including the lowest set bit of i, and
        # cut off the bits at positions >= M
        return a[a.index(1) + 1 : self.M]

    def combine_half(self, seq):
        """
        Combine random bitarrays with probability 1/2
        according to given operator sequence.
        """
        a = self.random_half()
        for k in seq:
            if k:
                a |= self.random_half()
            else:
                a &= self.random_half()
        return a

    def random_k(self, k):
        """
        Return random bitarray of length self.n with exactly k bits set.
        Raises ValueError unless 0 <= k <= self.n.
        """
        n = self.n
        # error check inputs and handle edge cases
        if k <= 0 or k >= n:
            if k == 0:
                return zeros(n, self.endian)
            if k == n:
                return ones(n, self.endian)
            raise ValueError("k must be in range 0 <= k <= n, got %s" % k)

        # exploit symmetry to establish: k <= n // 2
        if k > n // 2:
            a = self.random_k(n - k)
            a.invert()  # use in-place to avoid copying
            return a

        # decide on sequence, see VerificationTests devel/test_random.py
        if k < 16 or k * self.K < 3 * n:
            i = 0
        else:
            p = k / n  # p <= 0.5
            p -= (0.2 - 0.4 * p) / math.sqrt(n)
            i = int(p * (self.K + 1))

        # combine random bitarrays using bitwise AND and OR operations
        if i < 3:
            # start from all zeros, every one of the k bits is set below
            a = zeros(n, self.endian)
            diff = -k
        else:
            a = self.combine_half(self.op_seq(i))
            diff = a.count() - k

        # adjust the population count to exactly k by flipping randomly
        # chosen bits, one at a time
        randrange = random.randrange
        if diff < 0:  # not enough bits 1 - increase count
            for _ in range(-diff):
                i = randrange(n)
                while a[i]:
                    i = randrange(n)
                a[i] = 1
        elif diff > 0:  # too many bits 1 - decrease count
            for _ in range(diff):
                i = randrange(n)
                while not a[i]:
                    i = randrange(n)
                a[i] = 0

        return a

    def random_p(self, p):
        """
        Return random bitarray of length self.n in which each bit has
        probability p of being 1.
        Raises ValueError unless 0.0 <= p <= 1.0.
        """
        # error check inputs and handle edge cases
        if p <= 0.0 or p == 0.5 or p >= 1.0:
            if p == 0.0:
                return zeros(self.n, self.endian)
            if p == 0.5:
                return self.random_half()
            if p == 1.0:
                return ones(self.n, self.endian)
            raise ValueError("p must be in range 0.0 <= p <= 1.0, got %s" % p)

        # for small n, use literal definition
        if self.n < 16:
            return bitarray((random.random() < p for _ in range(self.n)),
                            self.endian)

        # exploit symmetry to establish: p < 0.5
        if p > 0.5:
            a = self.random_p(1.0 - p)
            a.invert()  # use in-place to avoid copying
            return a

        # for small p, set randomly individual bits
        if p < self.SMALL_P:
            return self.random_k(random.binomialvariate(self.n, p))

        # calculate operator sequence
        i = int(p * self.K)
        if p * (self.K + 1) > i + 1:  # see devel/test_random.py
            i += 1
        seq = self.op_seq(i)
        # probability reached by the operator sequence alone
        q = i / self.K

        # when n is small compared to number of operations, also use literal
        if self.n < 100 and self.nbytes <= len(seq) + 3 * bool(q != p):
            return bitarray((random.random() < p for _ in range(self.n)),
                            self.endian)

        # combine random bitarrays using bitwise AND and OR operations
        a = self.combine_half(seq)
        # correct the remaining difference between q and p:
        #   q + (1 - q) * x = p   for OR,     q * x = p   for AND
        if q < p:
            x = (p - q) / (1.0 - q)
            a |= self.random_p(x)
        elif q > p:
            x = p / q
            a &= self.random_p(x)

        return a

245 

246 

def gen_primes(__n, endian=None, odd=False):
    """gen_primes(n, /, endian=None, odd=False) -> bitarray

Generate a bitarray of length `n` in which active indices are prime numbers.
By default (`odd=False`), active indices correspond to prime numbers directly.
When `odd=True`, only odd prime numbers are represented in the resulting
bitarray `a`, and `a[i]` corresponds to `2*i+1` being prime or not.

Raises `ValueError` when `n` is negative.
"""
    n = int(__n)
    if n < 0:
        raise ValueError("bitarray length must be >= 0")

    # Build one period of the pattern left after removing the multiples of
    # the low primes, and remember in `f` the correct first 8 bits (the
    # pattern itself has the low primes cleared as well).
    if odd:
        a = ones(105, endian)  # 105 = 3 * 5 * 7
        a[1::3] = 0
        a[2::5] = 0
        a[3::7] = 0
        f = "01110110"
    else:
        a = ones(210, endian)  # 210 = 2 * 3 * 5 * 7
        for i in 2, 3, 5, 7:
            a[::i] = 0
        f = "00110101"

    # repeating the array many times is faster than setting the multiples
    # of the low primes to 0
    a *= (n + len(a) - 1) // len(a)
    a[:8] = bitarray(f, endian)
    del a[n:]
    # perform sieve starting at 11
    if odd:
        for i in a.search(1, 5, int(math.sqrt(n // 2) + 1.0)):  # 11//2 = 5
            j = 2 * i + 1
            a[(j * j) // 2 :: j] = 0
    else:
        # i*i is always odd, and even bits are already set to 0: use step 2*i
        for i in a.search(1, 11, int(math.sqrt(n) + 1.0)):
            a[i * i :: 2 * i] = 0
    return a

286 

287 

def sum_indices(__a, mode=1):
    """sum_indices(a, /, mode=1) -> int

Return sum of indices of all active bits in bitarray `a`.
Equivalent to `sum(i for i, v in enumerate(a) if v)`.
`mode=2` sums square of indices.

Raises `ValueError` when `mode` is neither 1 nor 2.
"""
    if mode not in (1, 2):
        raise ValueError("unexpected mode %r" % mode)

    # For details see: devel/test_sum_indices.py
    n = 1 << 19  # block size  512 Kbits
    if len(__a) <= n:  # shortcut for single block
        return _ssqi(__a, mode)

    # Constants
    m = n // 8  # block size in bytes
    o1 = n * (n - 1) // 2        # sum of all indices in a full block
    o2 = o1 * (2 * n - 1) // 3   # sum of all squared indices in a full block

    nblocks = (len(__a) + n - 1) // n
    padbits = __a.padbits
    sm = 0
    for i in range(nblocks):
        # use memoryview to avoid copying memory
        v = memoryview(__a)[i * m : (i + 1) * m]
        block = bitarray(None, __a.endian, buffer=v)
        if padbits and i == nblocks - 1:
            # The buffer covers whole bytes, so the last block also exposes
            # the pad bits.  Clear them so they cannot be counted; a
            # read-only buffer has to be copied before it can be modified.
            if block.readonly:
                block = bitarray(block)
            block[-padbits:] = 0

        k = block.count()
        if k:
            y = n * i  # offset of this block within __a
            # for a block in which all bits are set, use the constants
            z1 = o1 if k == n else _ssqi(block)
            if mode == 1:
                # sum of (y + j)  =  k * y + z1
                sm += k * y + z1
            else:
                z2 = o2 if k == n else _ssqi(block, 2)
                # sum of (y + j)**2  =  k * y**2 + 2 * y * z1 + z2
                sm += (k * y + 2 * z1) * y + z2

    return sm

331 

332 

def pprint(__a, stream=None, group=8, indent=4, width=80):
    """pprint(bitarray, /, stream=None, group=8, indent=4, width=80)

Pretty-print bitarray object to `stream`, defaults to `sys.stdout`.
By default, bits are grouped in bytes (8 bits), and 64 bits per line.
Non-bitarray objects are printed using `pprint.pprint()`.

Raises `ValueError` when `group < 1`, `indent < 0` or `width <= indent`.
"""
    if stream is None:
        stream = sys.stdout

    if not isinstance(__a, bitarray):
        import pprint as _pprint
        _pprint.pprint(__a, stream=stream, indent=indent, width=width)
        return

    group = int(group)
    if group < 1:
        raise ValueError('group must be >= 1')
    indent = int(indent)
    if indent < 0:
        raise ValueError('indent must be >= 0')
    width = int(width)
    if width <= indent:
        raise ValueError('width must be > %d (indent)' % indent)

    gpl = (width - indent) // (group + 1)  # groups per line
    epl = group * gpl                      # elements per line
    if epl == 0:
        # Not even a single group fits on a line.  The value is used as a
        # modulus below and must not be zero (width - indent == 2 used to
        # raise ZeroDivisionError), so put at least one element per line.
        epl = max(1, width - indent - 2)
    type_name = type(__a).__name__
    # here 4 is len("'()'")
    multiline = len(type_name) + 4 + len(__a) + len(__a) // group >= width
    if multiline:
        quotes = "'''"
    elif __a:
        quotes = "'"
    else:
        # empty bitarray: nothing between the parentheses
        quotes = ""

    stream.write("%s(%s" % (type_name, quotes))
    for i, b in enumerate(__a):
        if multiline and i % epl == 0:
            stream.write('\n%s' % (indent * ' '))
        # separate groups by a space, except at the beginning of a line
        if i % group == 0 and i % epl != 0:
            stream.write(' ')
        stream.write(str(b))

    if multiline:
        stream.write('\n')

    stream.write("%s)\n" % quotes)
    stream.flush()

385 

386 

def strip(__a, mode='right'):
    """strip(bitarray, /, mode='right') -> bitarray

Return a new bitarray with zeros stripped from left, right or both ends.
Allowed values for mode are the strings: `left`, `right`, `both`

Raises `TypeError` when `mode` is not a string, and `ValueError` when it
is not one of the allowed values.
"""
    if not isinstance(mode, str):
        # report the type of the offending argument, that is `mode`
        raise TypeError("str expected for mode, got '%s'" %
                        type(mode).__name__)
    if mode not in ('left', 'right', 'both'):
        raise ValueError("mode must be 'left', 'right' or 'both', got %r" %
                         mode)

    # a bound that stays None leaves that end of the slice untouched
    start = None if mode == 'right' else __a.find(1)
    if start == -1:
        # no bit is set at all: the result is empty
        return __a[:0]
    # position of the last 1, searching from the right (-1 + 1 = 0 when
    # there is none, which also yields an empty result)
    stop = None if mode == 'left' else __a.find(1, right=1) + 1
    return __a[start:stop]

405 

406 

def intervals(__a):
    """intervals(bitarray, /) -> iterator

Compute all uninterrupted intervals of 1s and 0s, and return an
iterator over tuples `(value, start, stop)`.  The intervals are guaranteed
to be in order, and their size is always non-zero (`stop - start > 0`).
For an empty bitarray, the iterator is empty.
"""
    try:
        value = __a[0]  # value of current interval
    except IndexError:
        # empty input: nothing to yield
        return
    n = len(__a)
    stop = 0  # "previous" stop - becomes next start

    while stop < n:
        start = stop
        # assert __a[start] == value
        try:  # find next occurrence of opposite value
            stop = __a.index(not value, start)
        except ValueError:
            # the current interval extends to the very end
            stop = n
        yield int(value), start, stop
        value = not value  # next interval has opposite value

430 

431 

def ba2int(__a, signed=False):
    """ba2int(bitarray, /, signed=False) -> int

Convert the given bitarray to an integer.
The bit-endianness of the bitarray is respected.
`signed` indicates whether two's complement is used to represent the integer.

Raises `TypeError` when the argument is not a bitarray, and `ValueError`
when it is empty.
"""
    if not isinstance(__a, bitarray):
        raise TypeError("bitarray expected, got '%s'" % type(__a).__name__)
    length = len(__a)
    if length == 0:
        raise ValueError("non-empty bitarray expected")

    if __a.padbits:
        # extend to a multiple of 8 bits by adding zeros at the most
        # significant end (this creates a new object, the argument is
        # not modified)
        pad = zeros(__a.padbits, __a.endian)
        __a = __a + pad if __a.endian == "little" else pad + __a

    res = int.from_bytes(__a.tobytes(), byteorder=__a.endian)

    # `length` is the original length (before padding): when the sign bit
    # is set, the value is negative in two's complement
    if signed and res >> length - 1:
        res -= 1 << length
    return res

454 

455 

def int2ba(__i, length=None, endian=None, signed=False):
    """int2ba(int, /, length=None, endian=None, signed=False) -> bitarray

Convert the given integer to a bitarray (with given bit-endianness,
and no leading (big-endian) / trailing (little-endian) zeros), unless
the `length` of the bitarray is provided.  An `OverflowError` is raised
if the integer is not representable with the given number of bits.
`signed` determines whether two's complement is used to represent the integer,
and requires `length` to be provided.

Raises `TypeError` when the integer or `length` is not an int, or when
`signed` is given without `length`, and `ValueError` when `length <= 0`.
"""
    if not isinstance(__i, int):
        raise TypeError("int expected, got '%s'" % type(__i).__name__)
    if length is not None:
        if not isinstance(length, int):
            raise TypeError("int expected for argument 'length'")
        if length <= 0:
            raise ValueError("length must be > 0")

    if signed:
        if length is None:
            raise TypeError("signed requires argument 'length'")
        m = 1 << length - 1
        if not (-m <= __i < m):
            raise OverflowError("signed integer not in range(%d, %d), "
                                "got %d" % (-m, m, __i))
        if __i < 0:
            # two's complement representation of the negative value
            __i += 1 << length
    else:  # unsigned
        if length and __i >> length:
            raise OverflowError("unsigned integer not in range(0, %d), "
                                "got %d" % (1 << length, __i))

    # the empty bitarray is created first, such that its (resolved)
    # endianness can be used as the byte order below
    a = bitarray(0, endian)
    b = __i.to_bytes(bits2bytes(__i.bit_length()), byteorder=a.endian)
    a.frombytes(b)
    le = a.endian == 'little'
    if length is None:
        # remove the zeros at the most significant end; for the integer 0
        # no bytes were added, and the result is a single 0 bit
        return strip(a, 'right' if le else 'left') if a else a + '0'

    if len(a) > length:
        # whole bytes were added: cut off bits at the most significant end
        return a[:length] if le else a[-length:]
    if len(a) == length:
        return a
    # len(a) < length, we need padding
    pad = zeros(length - len(a), a.endian)
    return a + pad if le else pad + a

502 

# ------------------------------ Huffman coding -----------------------------

504 

def _huffman_tree(__freq_map):
    """_huffman_tree(dict, /) -> Node

Given a dict mapping symbols to their frequency, construct a Huffman tree
and return its root node.  The dict must not be empty.
"""
    from heapq import heappush, heappop

    class Node:
        """
        There are two types of Node instances (both have 'freq' attribute):
          * leaf node: has 'symbol' attribute
          * parent node: has 'child' attribute (tuple with both children)
        """
        def __lt__(self, other):
            # heapq needs to be able to compare the nodes
            return self.freq < other.freq

    minheap = []
    # create all leaf nodes and push them onto the queue
    for sym, f in __freq_map.items():
        leaf = Node()
        leaf.symbol = sym
        leaf.freq = f
        heappush(minheap, leaf)

    # repeat the process until only one node remains
    while len(minheap) > 1:
        # take the two nodes with lowest frequencies from the queue
        # to construct a new parent node and push it onto the queue
        parent = Node()
        parent.child = heappop(minheap), heappop(minheap)
        parent.freq = parent.child[0].freq + parent.child[1].freq
        heappush(minheap, parent)

    # the single remaining node is the root of the Huffman tree
    return minheap[0]

542 

543 

def huffman_code(__freq_map, endian=None):
    """huffman_code(dict, /, endian=None) -> dict

Given a frequency map, a dictionary mapping symbols to their frequency,
calculate the Huffman code, i.e. a dict mapping those symbols to
bitarrays (with given bit-endianness).  Note that the symbols are not limited
to being strings.  Symbols may be any hashable object.

Raises `TypeError` when the argument is not a dict, and `ValueError` when
it is empty.
"""
    if not isinstance(__freq_map, dict):
        raise TypeError("dict expected, got '%s'" % type(__freq_map).__name__)

    if len(__freq_map) < 2:
        if len(__freq_map) == 0:
            raise ValueError("cannot create Huffman code with no symbols")
        # Only one symbol: Normally if only one symbol is given, the code
        # could be represented with zero bits.  However here, the code should
        # be at least one bit for the .encode() and .decode() methods to work.
        # So we represent the symbol by a single code of length one, in
        # particular one 0 bit.  This is an incomplete code, since if a 1 bit
        # is received, it has no meaning and will result in an error.
        sym = list(__freq_map)[0]
        return {sym: bitarray('0', endian)}

    result = {}

    # Walk the tree with an explicit stack instead of recursion, such that
    # very deep trees (strongly skewed frequencies) cannot exceed Python's
    # recursion limit.  The visiting order (child 0 before child 1) is the
    # same as for a recursive pre-order traversal.
    empty = bitarray(0, endian)
    stack = [(_huffman_tree(__freq_map), empty)]
    while stack:
        nd, prefix = stack.pop()
        try:  # leaf
            sym = nd.symbol
        except AttributeError:  # parent, so visit each child
            # child 1 is pushed first, such that child 0 is popped first
            stack.append((nd.child[1], prefix + '1'))
            stack.append((nd.child[0], prefix + '0'))
        else:
            result[sym] = prefix

    return result

578 

579 

def canonical_huffman(__freq_map):
    """canonical_huffman(dict, /) -> tuple

Given a frequency map, a dictionary mapping symbols to their frequency,
calculate the canonical Huffman code.  Returns a tuple containing:

0. the canonical Huffman code as a dict mapping symbols to bitarrays
1. a list containing the number of symbols of each code length
2. a list of symbols in canonical order

Note: the two lists may be used as input for `canonical_decode()`.

Raises `TypeError` when the argument is not a dict, and `ValueError` when
it is empty.
"""
    if not isinstance(__freq_map, dict):
        raise TypeError("dict expected, got '%s'" % type(__freq_map).__name__)

    if len(__freq_map) < 2:
        if len(__freq_map) == 0:
            raise ValueError("cannot create Huffman code with no symbols")
        # Only one symbol: it is represented by a single 0 bit (a code of
        # length one), as a code of length zero could not be used for
        # encoding and decoding.
        sym = list(__freq_map)[0]
        return {sym: bitarray('0', 'big')}, [0, 1], [sym]

    code_length = {}  # map symbols to their code length

    # Traverse the Huffman tree and record the depth at which each symbol is
    # reached.  An explicit stack is used instead of recursion, such that
    # very deep trees cannot exceed Python's recursion limit.  The visiting
    # order (child 0 before child 1) is the same as for a recursive pre-order
    # traversal, which matters as the sort below is stable.
    stack = [(_huffman_tree(__freq_map), 0)]
    while stack:
        nd, length = stack.pop()
        try:  # leaf
            sym = nd.symbol
        except AttributeError:  # parent, so visit each child
            # child 1 is pushed first, such that child 0 is popped first
            stack.append((nd.child[1], length + 1))
            stack.append((nd.child[0], length + 1))
        else:
            code_length[sym] = length

    # We now have a mapping of symbols to their code length, which is all we
    # need to construct a list of tuples (symbol, code length) sorted by
    # code length:
    table = sorted(code_length.items(), key=lambda item: item[1])

    maxbits = table[-1][1]
    codedict = {}
    # count[i] is the number of symbols with code length i
    count = (maxbits + 1) * [0]

    code = 0
    for i, (sym, length) in enumerate(table):
        codedict[sym] = int2ba(code, length, 'big')
        count[length] += 1
        if i + 1 < len(table):
            # the next code is the successor of the current one, extended
            # with zeros to the code length of the next symbol
            code += 1
            code <<= table[i + 1][1] - length

    return codedict, count, [item[0] for item in table]

631 

632 return codedict, count, [item[0] for item in table]