1from __future__ import annotations
2
3import struct
4import math
5from typing import Union, Dict, Callable, Optional
6import functools
7import bitstring
8from bitstring.fp8 import p4binary_fmt, p3binary_fmt
9from bitstring.mxfp import (e3m2mxfp_fmt, e2m3mxfp_fmt, e2m1mxfp_fmt, e4m3mxfp_saturate_fmt,
10 e5m2mxfp_saturate_fmt, e4m3mxfp_overflow_fmt, e5m2mxfp_overflow_fmt)
11
# Short module-level aliases for bitstring internals, so the code below can refer to
# them without the full dotted path.
# NOTE(review): MutableBitStore is not referenced in the visible part of this file;
# it may be used further down or imported from here by other modules - confirm before removing.
helpers = bitstring.bitstore_helpers
ConstBitStore = bitstring.bitstore.ConstBitStore
MutableBitStore = bitstring.bitstore.MutableBitStore
15
16
# Maximum number of distinct format strings whose parsed result is memoised.
CACHE_SIZE = 256


@functools.lru_cache(CACHE_SIZE)
def str_to_bitstore(s: str) -> ConstBitStore:
    """Build a single bitstore from the format string ``s``.

    The string is split into tokens by ``bitstring.utils.tokenparser``, each
    token is converted with ``bitstore_from_token`` and the pieces are joined
    in order. Results are memoised per input string (up to ``CACHE_SIZE``
    entries), so a repeated string returns the cached object.
    """
    # Only the token list is needed; the first item of the parser's result is unused here.
    _unused, parsed_tokens = bitstring.utils.tokenparser(s)
    return ConstBitStore.join([bitstore_from_token(*fields) for fields in parsed_tokens])
24
25
# Dispatch table for literal tokens: maps a literal prefix (accepted in either letter
# case) to the helper that builds a bitstore from the literal's value - hexadecimal
# ('0x'), binary ('0b') or octal ('0o').
literal_bit_funcs: Dict[str, Callable[..., ConstBitStore]] = {
    '0x': helpers.hex2bitstore,
    '0X': helpers.hex2bitstore,
    '0b': helpers.bin2bitstore,
    '0B': helpers.bin2bitstore,
    '0o': helpers.oct2bitstore,
    '0O': helpers.oct2bitstore,
}
34
35
def bitstore_from_token(name: str, token_length: Optional[int], value: Optional[str]) -> ConstBitStore:
    """Convert one parsed token into a bitstore.

    Args:
        name: Token name; either a literal prefix listed in ``literal_bit_funcs``
            (e.g. ``'0x'``) or a name accepted by ``bitstring.dtypes.Dtype``.
        token_length: Length given with the token, or ``None`` if none was given.
        value: The token's value; may only be ``None`` for the ``'pad'`` token.

    Returns:
        The bitstore built for the token.

    Raises:
        bitstring.CreationError: If the dtype cannot be created from ``name`` and
            ``token_length``, or if a length was given and the built value's
            length differs from the dtype's ``bitlength``.
        ValueError: If ``value`` is ``None`` for any token other than ``'pad'``.
    """
    if name in literal_bit_funcs:
        # Literal tokens are built from the value alone; token_length is not used for them.
        return literal_bit_funcs[name](value)
    try:
        d = bitstring.dtypes.Dtype(name, token_length)
    except ValueError as e:
        # Chain explicitly so the traceback reports the dtype error as the direct cause.
        raise bitstring.CreationError(f"Can't parse token: {e}") from e
    if value is None and name != 'pad':
        raise ValueError(f"Token {name} requires a value.")
    bs = d.build(value)._bitstore
    # When an explicit length was given, the built value must match the dtype's length exactly.
    if token_length is not None and len(bs) != d.bitlength:
        raise bitstring.CreationError(f"Token with length {token_length} packed with value of length {len(bs)} "
                                      f"({name}:{token_length}={value}).")
    return bs
50
51
52
def ue2bitstore(i: Union[str, int]) -> ConstBitStore:
    """Encode a non-negative integer as an unsigned exponential-Golomb code.

    For ``i > 0`` the code is ``n`` zero bits, a single one bit, and then the
    value ``i + 1 - 2**n`` written in ``n`` bits, where ``n`` is one less than
    the bit length of ``i + 1``. Zero is encoded as the single bit ``'1'``.

    Args:
        i: The value to encode; anything accepted by ``int()``.

    Returns:
        The bitstore holding the code.

    Raises:
        bitstring.CreationError: If ``i`` is negative.
    """
    i = int(i)
    if i < 0:
        raise bitstring.CreationError("Cannot use negative initialiser for unsigned exponential-Golomb.")
    if i == 0:
        return ConstBitStore.from_bin('1')
    # Position of the highest set bit of i + 1; this is the number of leading zeros in the code.
    leadingzeros = (i + 1).bit_length() - 1
    # What remains of i + 1 once its leading one bit is removed.
    remainingpart = i + 1 - (1 << leadingzeros)
    prefix = ConstBitStore.from_bin('0' * leadingzeros + '1')
    return prefix + helpers.int2bitstore(remainingpart, leadingzeros, False)
66
67
def se2bitstore(i: Union[str, int]) -> ConstBitStore:
    """Encode a signed integer as a signed exponential-Golomb code.

    Positive values map to the odd unsigned number ``2*i - 1``; zero and
    negative values map to the even number ``-2*i``. The mapped number is then
    encoded by ``ue2bitstore``.
    """
    n = int(i)
    mapped = 2 * n - 1 if n > 0 else -2 * n
    return ue2bitstore(mapped)
75
76
def uie2bitstore(i: Union[str, int]) -> ConstBitStore:
    """Encode a non-negative integer as an unsigned interleaved exponential-Golomb code.

    Zero is the single bit ``'1'``. Otherwise every binary digit of ``i + 1``
    after its leading one is written preceded by a ``'0'``, and a final ``'1'``
    terminates the code.

    Raises:
        bitstring.CreationError: If ``i`` is negative.
    """
    n = int(i)
    if n < 0:
        raise bitstring.CreationError("Cannot use negative initialiser for unsigned interleaved exponential-Golomb.")
    if n == 0:
        return ConstBitStore.from_bin('1')
    # bin() yields '0b1...'; dropping three characters removes the prefix and the leading one.
    trailing_digits = bin(n + 1)[3:]
    interleaved = '0' + '0'.join(trailing_digits) + '1'
    return ConstBitStore.from_bin(interleaved)
82
83
def sie2bitstore(i: Union[str, int]) -> ConstBitStore:
    """Encode a signed integer as a signed interleaved exponential-Golomb code.

    Zero is the single bit ``'1'``. Any other value is the unsigned interleaved
    code of its magnitude followed by one sign bit: ``'1'`` for negative,
    ``'0'`` for positive.
    """
    n = int(i)
    if n == 0:
        return ConstBitStore.from_bin('1')
    magnitude = uie2bitstore(abs(n))
    sign_bit = ConstBitStore.from_bin('1') if n < 0 else ConstBitStore.from_bin('0')
    return magnitude + sign_bit
90
91
def bfloat2bitstore(f: Union[str, float], big_endian: bool) -> ConstBitStore:
    """Convert ``f`` to a 16-bit bitstore by truncating its 32-bit float packing.

    The value is packed as a 4-byte float with ``struct`` in the requested byte
    order and only the two most significant bytes are kept (the first two for
    big-endian, the last two for little-endian). The discarded bytes are simply
    dropped, not rounded. Values too large to pack become infinity of the
    matching sign.
    """
    value = float(f)
    pack_format = '>f' if big_endian else '<f'
    try:
        packed = struct.pack(pack_format, value)
    except OverflowError:
        # For consistency, we overflow to 'inf'.
        packed = struct.pack(pack_format, math.inf if value > 0 else -math.inf)
    # The pack is always four bytes, so [2:] is the upper half in little-endian order.
    kept = packed[:2] if big_endian else packed[2:]
    return ConstBitStore.from_bytes(kept)
101
102
def p4binary2bitstore(f: Union[str, float]) -> ConstBitStore:
    """Convert ``f`` to an 8-bit bitstore using the code chosen by ``p4binary_fmt.float_to_int8``."""
    code = p4binary_fmt.float_to_int8(float(f))
    return helpers.int2bitstore(code, 8, False)
107
108
def p3binary2bitstore(f: Union[str, float]) -> ConstBitStore:
    """Convert ``f`` to an 8-bit bitstore using the code chosen by ``p3binary_fmt.float_to_int8``."""
    code = p3binary_fmt.float_to_int8(float(f))
    return helpers.int2bitstore(code, 8, False)
113
114
def e4m3mxfp2bitstore(f: Union[str, float]) -> ConstBitStore:
    """Convert ``f`` to an 8-bit e4m3mxfp bitstore.

    The conversion table is selected by ``bitstring.options.mxfp_overflow``:
    the saturating format when it equals ``'saturate'``, otherwise the
    overflow format.
    """
    value = float(f)
    saturating = bitstring.options.mxfp_overflow == 'saturate'
    fmt = e4m3mxfp_saturate_fmt if saturating else e4m3mxfp_overflow_fmt
    return helpers.int2bitstore(fmt.float_to_int(value), 8, False)
122
123
def e5m2mxfp2bitstore(f: Union[str, float]) -> ConstBitStore:
    """Convert ``f`` to an 8-bit e5m2mxfp bitstore.

    The conversion table is selected by ``bitstring.options.mxfp_overflow``:
    the saturating format when it equals ``'saturate'``, otherwise the
    overflow format.
    """
    value = float(f)
    saturating = bitstring.options.mxfp_overflow == 'saturate'
    fmt = e5m2mxfp_saturate_fmt if saturating else e5m2mxfp_overflow_fmt
    return helpers.int2bitstore(fmt.float_to_int(value), 8, False)
131
132
def e3m2mxfp2bitstore(f: Union[str, float]) -> ConstBitStore:
    """Convert ``f`` to a 6-bit e3m2mxfp bitstore.

    Raises:
        ValueError: If ``f`` is NaN, which this format cannot represent.
    """
    value = float(f)
    if not math.isnan(value):
        return helpers.int2bitstore(e3m2mxfp_fmt.float_to_int(value), 6, False)
    raise ValueError("Cannot convert float('nan') to e3m2mxfp format as it has no representation for it.")
139
140
def e2m3mxfp2bitstore(f: Union[str, float]) -> ConstBitStore:
    """Convert ``f`` to a 6-bit e2m3mxfp bitstore.

    Raises:
        ValueError: If ``f`` is NaN, which this format cannot represent.
    """
    value = float(f)
    if not math.isnan(value):
        return helpers.int2bitstore(e2m3mxfp_fmt.float_to_int(value), 6, False)
    raise ValueError("Cannot convert float('nan') to e2m3mxfp format as it has no representation for it.")
147
148
def e2m1mxfp2bitstore(f: Union[str, float]) -> ConstBitStore:
    """Convert ``f`` to a 4-bit e2m1mxfp bitstore.

    Raises:
        ValueError: If ``f`` is NaN, which this format cannot represent.
    """
    value = float(f)
    if not math.isnan(value):
        return helpers.int2bitstore(e2m1mxfp_fmt.float_to_int(value), 4, False)
    raise ValueError("Cannot convert float('nan') to e2m1mxfp format as it has no representation for it.")
155
156
# The 255 non-NaN values accepted for e8m0mxfp: 2**-127 up to 2**127, in code order.
e8m0mxfp_allowed_values = [float(2 ** x) for x in range(-127, 128)]

# Reverse lookup from each allowed value to its position in the list above (its 8-bit code),
# so a conversion is one hash lookup instead of a linear scan.
_e8m0mxfp_code_for_value = {value: code for code, value in enumerate(e8m0mxfp_allowed_values)}


def e8m0mxfp2bitstore(f: Union[str, float]) -> ConstBitStore:
    """Convert ``f`` to an 8-bit e8m0mxfp bitstore.

    NaN is encoded as ``'11111111'``. Any other value must equal one of
    ``e8m0mxfp_allowed_values`` exactly; its index in that list is the code.
    No rounding is performed.

    Args:
        f: The value to encode; anything accepted by ``float()``.

    Returns:
        The 8-bit bitstore for the value.

    Raises:
        ValueError: If ``f`` is neither NaN nor an exact allowed power of two.
    """
    f = float(f)
    if math.isnan(f):
        return ConstBitStore.from_bin('11111111')
    i = _e8m0mxfp_code_for_value.get(f)
    if i is None:
        raise ValueError(f"{f} is not a valid e8m0mxfp value. It must be exactly 2 ** i, for -127 <= i <= 127 or float('nan') as no rounding will be done.")
    return helpers.int2bitstore(i, 8, False)
169
170
def mxint2bitstore(f: Union[str, float]) -> ConstBitStore:
    """Convert ``f`` to an 8-bit mxint bitstore.

    The value is multiplied by 64 and rounded to the nearest integer, with
    ties going to the even integer. Scaled values above 127 saturate to
    ``'01111111'`` and scaled values at or below -128 give ``'10000000'``.

    Args:
        f: The value to encode; anything accepted by ``float()``.

    Returns:
        The 8-bit bitstore for the value.

    Raises:
        ValueError: If ``f`` is NaN, which this format cannot represent.
    """
    f = float(f)
    if math.isnan(f):
        raise ValueError("Cannot convert float('nan') to mxint format as it has no representation for it.")
    f *= 2 ** 6  # Remove the implicit scaling factor
    if f > 127:  # 1 + 63/64
        return ConstBitStore.from_bin('01111111')
    if f <= -128:  # -2
        return ConstBitStore.from_bin('10000000')
    # Built-in round() on a float rounds to nearest with ties to even, which is exactly
    # the rule required here. Infinities never reach this line because of the checks above.
    return helpers.int2bitstore(round(f), 8, True)