Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.9/dist-packages/bitstring/mxfp.py: 35%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1import array
2import math
3import struct
4import bitarray
5from bitstring.luts import mxfp_luts_compressed
6import zlib
7from typing import Optional
def round_to_nearest_ties_to_even(lut_int_to_float, lower: int, f: float) -> Optional[int]:
    """Round ``f`` to one of two adjacent LUT encodings, ties going to the even encoding.

    ``lower`` and ``lower + 1`` are consecutive integer encodings whose float values are
    looked up in ``lut_int_to_float``. If ``f`` equals either value, or lies strictly
    between them, the nearest encoding is returned. An exact midpoint returns whichever of
    the two integers is even. Otherwise (``f`` outside the interval, or NaN) ``None`` is
    returned.
    """
    lo_idx, hi_idx = lower, lower + 1
    # Special case for LUTs without a negative zero: index 128 is read as +0.0.
    lo_val = lut_int_to_float[lo_idx] if lo_idx != 128 else 0.0
    hi_val = lut_int_to_float[hi_idx]
    # Negative encodings decrease in value as the int grows, so order the pair by value.
    if hi_val < lo_val:
        lo_idx, hi_idx, lo_val, hi_val = hi_idx, lo_idx, hi_val, lo_val
    if f == lo_val:
        return lo_idx
    if f == hi_val:
        return hi_idx
    if not (lo_val < f < hi_val):
        return None
    below = f - lo_val
    above = hi_val - f
    if below == above:
        # Exact tie: exactly one of two consecutive ints is even; choose it.
        return lo_idx if lo_idx % 2 == 0 else hi_idx
    return lo_idx if below < above else hi_idx
class MXFPFormat:
    """Defining an MXFP micro-scaling floating point format.

    The layout is 1 sign bit, ``exp_bits`` exponent bits and ``mantissa_bits`` mantissa
    bits, with exponent ``bias``. ``mxfp_overflow`` selects how out-of-range values are
    handled: 'saturate' clamps to the largest finite value. Otherwise the e4m3/e5m2
    formats map to NaN/inf respectively.

    Conversions rely on two lookup tables, ``lut_int_to_float`` and
    ``lut_float16_to_mxfp``. They start as ``None`` and must be filled by
    :meth:`decompress_luts` or :meth:`create_luts` before :meth:`float_to_int` is used.
    """

    def __init__(self, exp_bits: int, mantissa_bits: int, bias: int, mxfp_overflow: str):
        self.exp_bits = exp_bits
        self.mantissa_bits = mantissa_bits
        self.bias = bias
        self.mxfp_overflow = mxfp_overflow

        # Default clamp targets: every bit below the sign bit set (largest positive
        # encoding), and every bit set (largest negative encoding).
        self.pos_clamp_value = (1 << (self.exp_bits + self.mantissa_bits)) - 1
        self.neg_clamp_value = (1 << (1 + self.exp_bits + self.mantissa_bits)) - 1

        # Special cases for e4m3 and e5m2, whose top encodings are reserved for NaN/inf.
        if self.exp_bits == 4 and self.mantissa_bits == 3:
            if self.mxfp_overflow == 'saturate':
                self.pos_clamp_value = 0b01111110  # 448
                self.neg_clamp_value = 0b11111110  # -448
            else:
                self.pos_clamp_value = self.neg_clamp_value = 0b11111111  # NaN
        if self.exp_bits == 5 and self.mantissa_bits == 2:
            if self.mxfp_overflow == 'saturate':
                self.pos_clamp_value = 0b01111011  # 57344
                self.neg_clamp_value = 0b11111011  # -57344
            else:
                self.pos_clamp_value = 0b01111100  # +inf
                self.neg_clamp_value = 0b11111100  # -inf

        # If we calculate these LUTs now it creates a bootstrap problem in generate_luts.py,
        # so they are filled in later by decompress_luts() or create_luts().
        # lut_float16_to_mxfp: bytes indexed by a float16 bit pattern.
        self.lut_float16_to_mxfp = None
        # lut_int_to_float: sequence of floats indexed by the mxfp encoding
        # (a tuple after decompress_luts(), an array.array('f') after create_luts()).
        self.lut_int_to_float = None

    def __str__(self):
        return f"MXFPFormat(exp_bits={self.exp_bits}, mantissa_bits={self.mantissa_bits}, bias={self.bias}, mxfp_overflow='{self.mxfp_overflow}')"

    def decompress_luts(self):
        """Load both LUTs from the zlib-compressed tables in ``mxfp_luts_compressed``.

        The tables are keyed by ``(exp_bits, mantissa_bits, bias, mxfp_overflow)``. The
        int-to-float table is stored as little-endian float32 values.
        """
        int_to_float_compressed, float16_to_mxfp_compressed = mxfp_luts_compressed[(self.exp_bits, self.mantissa_bits, self.bias, self.mxfp_overflow)]
        self.lut_float16_to_mxfp = zlib.decompress(float16_to_mxfp_compressed)
        dec = zlib.decompress(int_to_float_compressed)
        self.lut_int_to_float = struct.unpack(f'<{len(dec) // 4}f', dec)

    def create_luts(self):
        """Compute both LUTs from scratch (slow; used to generate the compressed tables)."""
        # Order matters: the float16 LUT is built by searching the int-to-float LUT.
        self.lut_int_to_float = self.createLUT_for_int_to_float()
        self.lut_float16_to_mxfp = self.createLUT_for_float16_to_mxfp()

    def float_to_int(self, f: float) -> int:
        """Given a Python float convert to the best mxfp float (expressed as an int) that represents it.

        The float is rounded to float16 and its 16-bit pattern is used as an index into
        ``lut_float16_to_mxfp``. Finite values too large for float16 return
        ``pos_clamp_value`` or ``neg_clamp_value`` according to their sign.
        """
        try:
            b = struct.pack('>e', f)
        except (OverflowError, struct.error):
            # Return the largest representable positive or negative value
            return self.pos_clamp_value if f > 0 else self.neg_clamp_value
        f16_int = int.from_bytes(b, byteorder='big')
        # Then use this as an index to our large LUT
        return self.lut_float16_to_mxfp[f16_int]

    def slow_float_to_int(self, f: float) -> int:
        """Convert a float to an mxfp encoding by linear search of ``lut_int_to_float``.

        Slow, but easier to follow than :meth:`float_to_int`. The returned int holds the
        bit pattern of the chosen encoding. Out-of-range values give the clamp values,
        and NaN gives 0xff.
        """
        length = 1 + self.exp_bits + self.mantissa_bits
        values = 1 << length
        if f >= 0:
            # Positive (or -0.0, which compares equal to 0), so top bit is not set.
            for i in range(values // 2 - 1):
                upper = self.lut_int_to_float[i + 1]
                if upper == float('inf'):
                    break
                x = round_to_nearest_ties_to_even(self.lut_int_to_float, i, f)
                if x is not None:
                    return x
            return self.pos_clamp_value
        if f < 0:
            # Negative, so top bit is set
            for i in range(values // 2, values - 1):
                lower = self.lut_int_to_float[i + 1]
                if lower == float('-inf'):
                    break
                x = round_to_nearest_ties_to_even(self.lut_int_to_float, i, f)
                if x is not None:
                    return x
            # Clip to negative max
            return self.neg_clamp_value
        assert math.isnan(f)
        # 0xff is NaN for both e5m2 and e4m3. Smaller formats have no NaN, so for them 0xff
        # is deliberately an invalid value that can be detected later.
        return 0xff

    def createLUT_for_int_to_float(self) -> array.array:
        """Create a LUT to convert an int representing a MXFP float into a Python float.

        Every encoding is decoded from its sign/exponent/mantissa fields (exponent 0 is
        treated as subnormal). For 8-bit formats the reserved inf/NaN encodings are then
        patched in.
        """
        # 'import bitarray' alone does not load the bitarray.util submodule; import it
        # explicitly rather than relying on another module having done so.
        import bitarray.util
        i2f = []
        length = 1 + self.exp_bits + self.mantissa_bits
        for i in range(1 << length):
            b = bitarray.util.int2ba(i, length=length, endian='big', signed=False)
            sign = b[0]
            exponent = bitarray.util.ba2int(b[1:1 + self.exp_bits])
            significand = b[1 + self.exp_bits:]
            if exponent == 0:
                # Subnormal: implicit leading 0 and the minimum exponent.
                significand = bitarray.bitarray('0') + significand
                exponent = -self.bias + 1
            else:
                significand = bitarray.bitarray('1') + significand
                exponent -= self.bias
            f = float(bitarray.util.ba2int(significand)) / (2.0 ** self.mantissa_bits)
            f *= 2 ** exponent
            if length == 8:
                # Some special cases
                if self.exp_bits == 5:
                    if i in [0b01111100, 0b11111100]:
                        f = float('inf')
                    if i in [0b01111101, 0b11111101, 0b01111110, 0b11111110, 0b01111111, 0b11111111]:
                        f = float('nan')
                if self.exp_bits == 4:
                    if i in [0b01111111, 0b11111111]:
                        f = float('nan')
            i2f.append(f if not sign else -f)
        return array.array('f', i2f)

    def createLUT_for_float16_to_mxfp(self) -> bytes:
        """Create a LUT to convert a float16 (by bit pattern) into a MXFP format encoding.

        Requires ``lut_int_to_float`` to be populated. The 8-bit formats round with the
        third-party ``gfloat`` package. The 4- and 6-bit formats use
        :meth:`slow_float_to_int`.
        """
        # Used to create the LUT that was compressed and stored for the fp8 code
        length = 1 + self.exp_bits + self.mantissa_bits
        if length == 8:
            import gfloat
            fi = gfloat.formats.format_info_ocp_e5m2 if self.exp_bits == 5 else gfloat.formats.format_info_ocp_e4m3
            fp16_to_fp8 = bytearray(1 << 16)
            for i in range(1 << 16):
                b = struct.pack('>H', i)
                f, = struct.unpack('>e', b)
                fp = gfloat.round_float(fi, f, sat=self.mxfp_overflow == 'saturate')
                if math.isnan(fp):
                    fp8_i = 0b11111111
                else:
                    # -0.0 == 0.0, so negative zero maps to the first (positive) zero.
                    fp8_i = self.lut_int_to_float.index(fp)
                fp16_to_fp8[i] = fp8_i
            return bytes(fp16_to_fp8)
        else:
            assert length in [4, 6]
            fp16_to_mxfp = bytearray(1 << 16)
            for i in range(1 << 16):
                b = struct.pack('>H', i)
                f, = struct.unpack('>e', b)
                mxfp_i = self.slow_float_to_int(f)
                if mxfp_i == 1 << (self.exp_bits + self.mantissa_bits):
                    # Got back int representing binary digits for negative zero. Just convert to positive zero instead.
                    mxfp_i = 0
                fp16_to_mxfp[i] = mxfp_i
            return bytes(fp16_to_mxfp)
# Predefined formats. The 4- and 6-bit formats have no inf/NaN encodings and saturate.
e2m1mxfp_fmt = MXFPFormat(
    exp_bits=2, mantissa_bits=1, bias=1, mxfp_overflow='saturate')
e2m3mxfp_fmt = MXFPFormat(
    exp_bits=2, mantissa_bits=3, bias=1, mxfp_overflow='saturate')
e3m2mxfp_fmt = MXFPFormat(
    exp_bits=3, mantissa_bits=2, bias=3, mxfp_overflow='saturate')

# The 8-bit formats come in two variants that differ only in overflow handling.
e4m3mxfp_saturate_fmt = MXFPFormat(
    exp_bits=4, mantissa_bits=3, bias=7, mxfp_overflow='saturate')
e4m3mxfp_overflow_fmt = MXFPFormat(
    exp_bits=4, mantissa_bits=3, bias=7, mxfp_overflow='overflow')
e5m2mxfp_saturate_fmt = MXFPFormat(
    exp_bits=5, mantissa_bits=2, bias=15, mxfp_overflow='saturate')
e5m2mxfp_overflow_fmt = MXFPFormat(
    exp_bits=5, mantissa_bits=2, bias=15, mxfp_overflow='overflow')
def decompress_luts():
    """Fill in the lookup tables of every predefined MXFP format from the compressed data."""
    all_formats = (
        e2m1mxfp_fmt,
        e2m3mxfp_fmt,
        e3m2mxfp_fmt,
        e4m3mxfp_saturate_fmt,
        e5m2mxfp_saturate_fmt,
        e4m3mxfp_overflow_fmt,
        e5m2mxfp_overflow_fmt,
    )
    for mxfp_format in all_formats:
        mxfp_format.decompress_luts()