Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.9/dist-packages/bitarray/util.py: 24%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

213 statements  

1# Copyright (c) 2019 - 2024, Ilan Schnell; All Rights Reserved 

2# bitarray is published under the PSF license. 

3# 

4# Author: Ilan Schnell 

5""" 

6Useful utilities for working with bitarrays. 

7""" 

8from __future__ import absolute_import 

9 

10import os 

11import sys 

12 

13from bitarray import bitarray, bits2bytes 

14 

15from bitarray._util import ( 

16 zeros, ones, count_n, parity, 

17 count_and, count_or, count_xor, any_and, subset, 

18 _correspond_all, 

19 serialize, deserialize, 

20 ba2hex, hex2ba, 

21 ba2base, base2ba, 

22 sc_encode, sc_decode, 

23 vl_encode, vl_decode, 

24 canonical_decode, 

25) 

26 

# Public names of this module (what `from bitarray.util import *` exposes).
__all__ = [
    'zeros',
    'ones',
    'urandom',
    'pprint',
    'make_endian',
    'rindex',
    'strip',
    'count_n',
    'parity',
    'count_and',
    'count_or',
    'count_xor',
    'any_and',
    'subset',
    'intervals',
    'ba2hex',
    'hex2ba',
    'ba2base',
    'base2ba',
    'ba2int',
    'int2ba',
    'serialize',
    'deserialize',
    'sc_encode',
    'sc_decode',
    'vl_encode',
    'vl_decode',
    'huffman_code',
    'canonical_huffman',
    'canonical_decode',
]

40 

41 

42_is_py2 = bool(sys.version_info[0] == 2) 

43 

44 

def urandom(__length, endian=None):
    """urandom(length, /, endian=None) -> bitarray

Return a bitarray holding `length` random bits, with the random data
obtained from `os.urandom`.
"""
    result = bitarray(0, endian)
    random_bytes = os.urandom(bits2bytes(__length))
    result.frombytes(random_bytes)
    # whole bytes were appended, so drop any bits beyond the requested length
    del result[__length:]
    return result

54 

55 

def rindex(__a, __sub_bitarray=1, __start=0, __stop=sys.maxsize):
    """rindex(bitarray, sub_bitarray=1, start=0, stop=<end>, /) -> int

Return rightmost (highest) index where sub_bitarray (or item - defaults
to 1) is found in bitarray (`a`), such that sub_bitarray is contained
within `a[start:stop]`.
Raises `ValueError` when the sub_bitarray is not present, and `TypeError`
when the first argument is not a bitarray.

This function is deprecated and emits a `DeprecationWarning` on every call;
use the `.index(..., right=True)` method instead.
"""
    from warnings import warn

    # stacklevel=2 attributes the warning to the code calling rindex(), not
    # to this module, so the default warning filters can show it to the
    # user whose code needs to be updated.
    warn("rindex() is deprecated and will be removed in bitarray 3.0 - "
         "use .index(..., right=True) method instead.",
         DeprecationWarning, stacklevel=2)

    if not isinstance(__a, bitarray):
        raise TypeError("bitarray expected, got '%s'" % type(__a).__name__)

    return __a.index(__sub_bitarray, __start, __stop, right=True)

74 

75 

def pprint(__a, stream=None, group=8, indent=4, width=80):
    """pprint(bitarray, /, stream=None, group=8, indent=4, width=80)

Prints the formatted representation of object on `stream` (which defaults
to `sys.stdout`).  By default, elements are grouped in bytes (8 elements),
and 8 bytes (64 elements) per line.
Non-bitarray objects are printed by the standard library
function `pprint.pprint()`.

`group` is the number of elements per space separated group (>= 1),
`indent` the number of spaces each line is indented by in multi-line
output (>= 0), and `width` the maximal line width (> `indent`).
Raises `ValueError` when one of these constraints is violated.
"""
    if stream is None:
        stream = sys.stdout

    if not isinstance(__a, bitarray):
        import pprint as _pprint
        _pprint.pprint(__a, stream=stream, indent=indent, width=width)
        return

    group = int(group)
    if group < 1:
        raise ValueError('group must be >= 1')
    indent = int(indent)
    if indent < 0:
        raise ValueError('indent must be >= 0')
    width = int(width)
    if width <= indent:
        raise ValueError('width must be > %d (indent)' % indent)

    gpl = (width - indent) // (group + 1)  # groups per line
    epl = group * gpl                      # elements per line
    if epl == 0:
        # Not even a single group fits on a line.  Clamp to at least one
        # element per line: a value of 0 (when width - indent == 2) would
        # make the modulo operations below raise ZeroDivisionError.
        epl = max(1, width - indent - 2)
    type_name = type(__a).__name__
    # here 4 is len("'()'")
    multiline = len(type_name) + 4 + len(__a) + len(__a) // group >= width
    if multiline:
        quotes = "'''"
    elif __a:
        quotes = "'"
    else:
        quotes = ""

    stream.write("%s(%s" % (type_name, quotes))
    for i, b in enumerate(__a):
        if multiline and i % epl == 0:
            # start a new, indented line
            stream.write('\n%s' % (indent * ' '))
        if i % group == 0 and i % epl != 0:
            # separate groups, except at the start of a line
            stream.write(' ')
        stream.write(str(b))

    if multiline:
        stream.write('\n')

    stream.write("%s)\n" % quotes)
    stream.flush()

130 

131 

def make_endian(__a, endian):
    """make_endian(bitarray, /, endian) -> bitarray

When the endianness of the given bitarray is different from `endian`,
return a new bitarray, with endianness `endian` and the same elements
as the original bitarray.
Otherwise (endianness is already `endian`) the original bitarray is returned
unchanged.
Raises `TypeError` when the first argument is not a bitarray.

This function is deprecated and emits a `DeprecationWarning` on every call;
use `bitarray(..., endian=...)` instead.
"""
    from warnings import warn

    # stacklevel=2 attributes the warning to the code calling make_endian(),
    # not to this module, so the default warning filters can show it to the
    # user whose code needs to be updated.
    warn("make_endian() is deprecated and will be removed in bitarray 3.0 - "
         "use bitarray(..., endian=...) instead",
         DeprecationWarning, stacklevel=2)

    if not isinstance(__a, bitarray):
        raise TypeError("bitarray expected, got '%s'" % type(__a).__name__)

    if __a.endian() == endian:
        return __a

    return bitarray(__a, endian)

154 

155 

def strip(__a, mode='right'):
    """strip(bitarray, /, mode='right') -> bitarray

Return a new bitarray with zeros stripped from left, right or both ends.
Allowed values for mode are the strings: `left`, `right`, `both`
Raises `TypeError` when `mode` is not a string, and `ValueError` when it
is not one of the allowed values.
"""
    if not isinstance(mode, str):
        # report the type of the offending argument (mode), not of the bitarray
        raise TypeError("str expected for mode, got '%s'" %
                        type(mode).__name__)
    if mode not in ('left', 'right', 'both'):
        raise ValueError("mode must be 'left', 'right' or 'both', got %r" %
                         mode)

    # index of the first 1 bit (None keeps the left end untouched)
    start = None if mode == 'right' else __a.find(1)
    if start == -1:
        # no 1 bit at all: everything is stripped
        return __a[:0]
    # one past the index of the last 1 bit (None keeps the right end untouched)
    stop = None if mode == 'left' else __a.find(1, right=1) + 1
    return __a[start:stop]

173 

174 

def intervals(__a):
    """intervals(bitarray, /) -> iterator

Iterate over the maximal runs of equal bits.  Each run is reported as a
tuple `(value, start, stop)`; runs are produced in increasing order and
are never empty (`stop - start > 0`).
"""
    try:
        current = __a[0]          # bit value of the run being scanned
    except IndexError:
        return                    # nothing to report for an empty input
    total = len(__a)
    begin = 0                     # start of the run being scanned

    while begin < total:
        # the run ends where the opposite value first shows up
        try:
            end = __a.index(not current, begin)
        except ValueError:
            end = total
        yield int(current), begin, end
        # runs alternate in value, and the next one starts where this ended
        current = not current
        begin = end

198 

199 

def ba2int(__a, signed=False):
    """ba2int(bitarray, /, signed=False) -> int

Return the integer represented by the given (non-empty) bitarray, taking
its bit-endianness into account.
When `signed` is true, the bits are interpreted as two's complement.
"""
    if not isinstance(__a, bitarray):
        raise TypeError("bitarray expected, got '%s'" % type(__a).__name__)
    nbits = len(__a)
    if nbits == 0:
        raise ValueError("non-empty bitarray expected")

    little = __a.endian() == 'little'
    npad = __a.padbits
    if npad:
        # fill up to whole bytes with zeros on the most significant side
        filler = zeros(npad, __a.endian())
        if little:
            __a = __a + filler
        else:
            __a = filler + __a

    if _is_py2:
        big = bitarray(__a, 'big')
        if little:
            big.reverse()
        value = int(ba2hex(big), 16)
    else:  # py3
        value = int.from_bytes(__a.tobytes(), byteorder=__a.endian())

    if signed:
        sign_bit = 1 << (nbits - 1)
        if value >= sign_bit:
            # sign bit set: map into the negative range
            value -= 1 << nbits
    return value

229 

230 

def int2ba(__i, length=None, endian=None, signed=False):
    """int2ba(int, /, length=None, endian=None, signed=False) -> bitarray

Return a bitarray (of the requested endianness) representing the given
integer.  Without `length`, the result carries no leading (big-endian) /
trailing (little-endian) zeros; with `length`, it has exactly that many bits
and `OverflowError` is raised when the integer does not fit.
`signed` selects two's complement representation and can only be used
together with `length`.
"""
    int_types = (int, long) if _is_py2 else int
    if not isinstance(__i, int_types):
        raise TypeError("int expected, got '%s'" % type(__i).__name__)
    if length is None:
        if signed:
            raise TypeError("signed requires length")
    else:
        if not isinstance(length, int):
            raise TypeError("int expected for length")
        if length <= 0:
            raise ValueError("length must be > 0")

    if __i == 0:
        # zero is handled up front, as it is a special case for the
        # conversion code below
        return zeros(length or 1, endian)

    if signed:
        half = 1 << (length - 1)
        if __i < -half or __i >= half:
            raise OverflowError("signed integer not in range(%d, %d), "
                                "got %d" % (-half, half, __i))
        if __i < 0:
            # two's complement of a negative number
            __i += 1 << length
    else:  # unsigned
        if __i < 0:
            raise OverflowError("unsigned integer not positive, got %d" % __i)
        if length and __i >= (1 << length):
            raise OverflowError("unsigned integer not in range(0, %d), "
                                "got %d" % (1 << length, __i))

    result = bitarray(0, endian)
    little = result.endian() == 'little'
    if _is_py2:
        digits = hex(__i)[2:].rstrip('L')
        result.extend(hex2ba(digits, 'big'))
        if little:
            result.reverse()
    else:  # py3
        nbytes = bits2bytes(__i.bit_length())
        result.frombytes(__i.to_bytes(nbytes, byteorder=result.endian()))

    if length is None:
        # remove the zeros on the most significant side
        return strip(result, 'right' if little else 'left')

    have = len(result)
    if have > length:
        # cut off surplus bits on the most significant side
        result = result[:length] if little else result[-length:]
    elif have < length:
        # extend with zeros on the most significant side
        filler = zeros(length - have, result.endian())
        result = result + filler if little else filler + result
    assert len(result) == length
    return result

291 

292# ------------------------------ Huffman coding ----------------------------- 

293 

294def _huffman_tree(__freq_map): 

295 """_huffman_tree(dict, /) -> Node 

296 

297Given a dict mapping symbols to their frequency, construct a Huffman tree 

298and return its root node. 

299""" 

300 from heapq import heappush, heappop 

301 

302 class Node(object): 

303 """ 

304 A Node instance will either have a 'symbol' (leaf node) or 

305 a 'child' (a tuple with both children) attribute. 

306 The 'freq' attribute will always be present. 

307 """ 

308 def __lt__(self, other): 

309 # heapq needs to be able to compare the nodes 

310 return self.freq < other.freq 

311 

312 minheap = [] 

313 # create all leaf nodes and push them onto the queue 

314 for sym, f in __freq_map.items(): 

315 leaf = Node() 

316 leaf.symbol = sym 

317 leaf.freq = f 

318 heappush(minheap, leaf) 

319 

320 # repeat the process until only one node remains 

321 while len(minheap) > 1: 

322 # take the two nodes with lowest frequencies from the queue 

323 # to construct a new node and push it onto the queue 

324 parent = Node() 

325 parent.child = heappop(minheap), heappop(minheap) 

326 parent.freq = parent.child[0].freq + parent.child[1].freq 

327 heappush(minheap, parent) 

328 

329 # the single remaining node is the root of the Huffman tree 

330 return minheap[0] 

331 

332 

def huffman_code(__freq_map, endian=None):
    """huffman_code(dict, /, endian=None) -> dict

Calculate the Huffman code for a frequency map (a dictionary mapping
symbols to their frequency).  The result is a dict mapping each symbol to
a bitarray (with given endianness).  Symbols are not limited to strings;
they may be any hashable object (such as `None`).
"""
    if not isinstance(__freq_map, dict):
        raise TypeError("dict expected, got '%s'" % type(__freq_map).__name__)

    zero = bitarray('0', endian)
    one = bitarray('1', endian)

    nsymbols = len(__freq_map)
    if nsymbols == 0:
        raise ValueError("cannot create Huffman code with no symbols")
    if nsymbols == 1:
        # A single symbol could in principle be coded with zero bits, but
        # the .encode() and .decode() methods need at least one bit per
        # symbol.  The symbol therefore gets the one-bit code 0.  This code
        # is incomplete: a received 1 bit has no meaning and is an error.
        return {next(iter(__freq_map)): zero}

    codes = {}

    def walk(node, path):
        if hasattr(node, 'symbol'):  # leaf: path is the code of the symbol
            codes[node.symbol] = path
        else:  # inner node: 0 leads to the first child, 1 to the second
            first, second = node.child
            walk(first, path + zero)
            walk(second, path + one)

    walk(_huffman_tree(__freq_map), bitarray(0, endian))
    return codes

369 

370 

def canonical_huffman(__freq_map):
    """canonical_huffman(dict, /) -> tuple

Calculate the canonical Huffman code for a frequency map (a dictionary
mapping symbols to their frequency).  Returns a tuple containing:

0. the canonical Huffman code as a dict mapping symbols to bitarrays
1. a list containing the number of symbols of each code length
2. a list of symbols in canonical order

Note: the two lists may be used as input for `canonical_decode()`.
"""
    if not isinstance(__freq_map, dict):
        raise TypeError("dict expected, got '%s'" % type(__freq_map).__name__)

    nsymbols = len(__freq_map)
    if nsymbols == 0:
        raise ValueError("cannot create Huffman code with no symbols")
    if nsymbols == 1:
        # a single symbol is represented by the one-bit code 0, as the
        # code needs to be at least one bit long
        only = next(iter(__freq_map))
        return {only: bitarray('0', 'big')}, [0, 1], [only]

    lengths = {}  # code length of each symbol

    def walk(node, depth):
        # only the depth at which each symbol sits in the tree is needed
        if hasattr(node, 'symbol'):  # leaf
            lengths[node.symbol] = depth
        else:  # inner node
            first, second = node.child
            walk(first, depth + 1)
            walk(second, depth + 1)

    walk(_huffman_tree(__freq_map), 0)

    # canonical order: by code length first, then by symbol
    ordered = sorted(lengths.items(), key=lambda item: (item[1], item[0]))

    count = (max(lengths.values()) + 1) * [0]
    codedict = {}

    code = 0
    previous = ordered[0][1]
    for sym, nbits in ordered:
        # make room for the additional bits when the code length grows
        code <<= nbits - previous
        codedict[sym] = int2ba(code, nbits, 'big')
        count[nbits] += 1
        code += 1
        previous = nbits

    return codedict, count, [sym for sym, _ in ordered]

406 # we now have a mapping of symbols to their code length, 

407 # which is all we need 

408 

409 table = sorted(code_length.items(), key=lambda item: (item[1], item[0])) 

410 

411 maxbits = max(item[1] for item in table) 

412 codedict = {} 

413 count = (maxbits + 1) * [0] 

414 

415 code = 0 

416 for i, (sym, length) in enumerate(table): 

417 codedict[sym] = int2ba(code, length, 'big') 

418 count[length] += 1 

419 if i + 1 < len(table): 

420 code += 1 

421 code <<= table[i + 1][1] - length 

422 

423 return codedict, count, [item[0] for item in table]