Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/bitstring/mxfp.py: 35%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

158 statements  

1import array 

2import math 

3import struct 

4import bitarray 

5from bitstring.luts import mxfp_luts_compressed 

6import zlib 

7from typing import Optional 

8 

9 

def round_to_nearest_ties_to_even(lut_int_to_float, lower: int, f: float) -> Optional[int]:
    """Round ``f`` to the nearer of the two adjacent codes ``lower`` and ``lower + 1``.

    :param lut_int_to_float: Sequence mapping an integer code to its float value.
    :param lower: The first of the two adjacent codes to consider.
    :param f: The value to round.
    :return: The chosen code if ``f`` lies in the closed interval spanned by the two
        codes' values, otherwise ``None`` (also ``None`` for NaN). On an exact tie the
        code with an even index wins.
    """
    lo_code, hi_code = lower, lower + 1
    # Code 128 is always treated as +0.0, for LUTs without a negative zero at that index.
    lo_val = lut_int_to_float[lo_code] if lo_code != 128 else 0.0
    hi_val = lut_int_to_float[hi_code]
    # Negative codes decode in descending order; orient the pair so that lo_val <= hi_val.
    if hi_val < lo_val:
        lo_code, hi_code = hi_code, lo_code
        lo_val, hi_val = hi_val, lo_val

    if f == lo_val:
        return lo_code
    if f == hi_val:
        return hi_code
    if not (lo_val < f < hi_val):
        return None

    dist_lo = f - lo_val
    dist_hi = hi_val - f
    if dist_lo == dist_hi:
        # Exact tie: prefer the code whose least-significant bit is zero.
        return hi_code if lo_code % 2 else lo_code
    return lo_code if dist_lo < dist_hi else hi_code

31 

32 

class MXFPFormat:
    """Defining an MXFP micro-scaling floating point format.

    Holds the bit layout (sign bit + ``exp_bits`` + ``mantissa_bits``), the exponent
    ``bias``, the overflow policy and, once populated, two lookup tables:

    * ``lut_int_to_float`` - indexed by the format's integer code, gives its float value.
    * ``lut_float16_to_mxfp`` - indexed by a float16 bit pattern, gives the integer code.
    """

    def __init__(self, exp_bits: int, mantissa_bits: int, bias: int, mxfp_overflow: str):
        """
        :param exp_bits: Number of exponent bits.
        :param mantissa_bits: Number of explicitly stored mantissa bits.
        :param bias: Exponent bias.
        :param mxfp_overflow: Overflow policy. For e4m3/e5m2, ``'saturate'`` clamps to the
            largest finite value; any other value (the module uses ``'overflow'``) clamps
            e4m3 to NaN and e5m2 to +/-inf. Other layouts always clamp to their
            largest-magnitude code.
        """
        self.exp_bits = exp_bits
        self.mantissa_bits = mantissa_bits
        self.bias = bias
        self.mxfp_overflow = mxfp_overflow

        # Default clamp codes: every exponent/mantissa bit set (largest magnitude), with
        # the sign bit clear for the positive clamp and set for the negative clamp.
        self.pos_clamp_value = (1 << (self.exp_bits + self.mantissa_bits)) - 1
        self.neg_clamp_value = (1 << (1 + self.exp_bits + self.mantissa_bits)) - 1

        # Special cases for e4m3 and e5m2
        if self.exp_bits == 4 and self.mantissa_bits == 3:
            if self.mxfp_overflow == 'saturate':
                self.pos_clamp_value = 0b01111110  # 448
                self.neg_clamp_value = 0b11111110  # -448
            else:
                self.pos_clamp_value = self.neg_clamp_value = 0b11111111  # NaN
        if self.exp_bits == 5 and self.mantissa_bits == 2:
            if self.mxfp_overflow == 'saturate':
                self.pos_clamp_value = 0b01111011  # 57344
                self.neg_clamp_value = 0b11111011  # -57344
            else:
                self.pos_clamp_value = 0b01111100  # +inf
                self.neg_clamp_value = 0b11111100  # -inf

        # If we calculate these LUTs now it creates a bootstrap problem in generate_luts.py.
        self.lut_float16_to_mxfp = None
        self.lut_int_to_float = None

    def __str__(self):
        return f"MXFPFormat(exp_bits={self.exp_bits}, mantissa_bits={self.mantissa_bits}, bias={self.bias}, mxfp_overflow='{self.mxfp_overflow}')"

    def decompress_luts(self):
        """Populate both LUTs from the precomputed, zlib-compressed tables.

        The tables are looked up in ``mxfp_luts_compressed`` by
        ``(exp_bits, mantissa_bits, bias, mxfp_overflow)``. The float16 -> code table is
        kept as raw bytes; the code -> float table is stored as little-endian float32
        values and unpacked into a tuple.
        """
        int_to_float_compressed, float16_to_mxfp_compressed = mxfp_luts_compressed[(self.exp_bits, self.mantissa_bits, self.bias, self.mxfp_overflow)]
        self.lut_float16_to_mxfp = zlib.decompress(float16_to_mxfp_compressed)
        dec = zlib.decompress(int_to_float_compressed)
        self.lut_int_to_float = struct.unpack(f'<{len(dec) // 4}f', dec)

    def create_luts(self):
        """Compute both LUTs from scratch (the int -> float table first, as the other needs it)."""
        self.lut_int_to_float = self.createLUT_for_int_to_float()
        self.lut_float16_to_mxfp = self.createLUT_for_float16_to_mxfp()

    def float_to_int(self, f: float) -> int:
        """Given a Python float convert to the best mxfp float (expressed as an int) that represents it.

        The value is first rounded to float16 and then mapped through
        ``lut_float16_to_mxfp``, so the result reflects that intermediate rounding.
        Values too large for float16 return the positive or negative clamp code.
        """
        # First convert the float to a float16, then a 16 bit uint
        try:
            b = struct.pack('>e', f)
        except (OverflowError, struct.error):
            # Return the largest representable positive or negative value
            return self.pos_clamp_value if f > 0 else self.neg_clamp_value

        f16_int = int.from_bytes(b, byteorder='big')
        # Then use this as an index to our large LUT
        return self.lut_float16_to_mxfp[f16_int]

    def slow_float_to_int(self, f: float) -> int:
        """Reference conversion of ``f`` to an integer code by scanning ``lut_int_to_float``.

        Slow, but easier to follow than the LUT-based ``float_to_int``. Requires
        ``lut_int_to_float`` to be populated. Adjacent pairs of codes of the matching
        sign are tried in turn with ties-to-even rounding. If none matches, the clamp
        code for that sign is returned.
        """
        length = 1 + self.exp_bits + self.mantissa_bits
        values = 1 << length
        if math.isnan(f):
            # 0xff is the NaN code for both 8-bit formats (e5m2 and e4m3). Smaller formats
            # have no NaN and 0xff does not fit in them, so it acts as an invalid value to
            # detect later.
            return 0xff
        # copysign distinguishes 0.0 from -0.0.
        is_positive = math.copysign(1.0, f) == 1.0
        if is_positive:
            # Positive codes have the top bit clear.
            for i in range(values // 2 - 1):
                next_value = self.lut_int_to_float[i + 1]
                if next_value == float('inf'):
                    break
                x = round_to_nearest_ties_to_even(self.lut_int_to_float, i, f)
                if x is not None:
                    return x
            return self.pos_clamp_value
        # Negative codes have the top bit set.
        for i in range(values // 2, values - 1):
            next_value = self.lut_int_to_float[i + 1]
            if next_value == float('-inf'):
                break
            x = round_to_nearest_ties_to_even(self.lut_int_to_float, i, f)
            if x is not None:
                return x
        # Clip to negative max
        return self.neg_clamp_value

    def createLUT_for_int_to_float(self) -> array.array:
        """Create a LUT to convert an int representing a MXFP float into a Python float.

        Each code is split into sign, exponent and mantissa fields with plain integer
        shifts and masks. A zero exponent field decodes as subnormal (no implicit leading
        1, exponent ``1 - bias``). For 8-bit layouts, the e5m2 and e4m3 infinity/NaN
        codes are overridden.
        """
        length = 1 + self.exp_bits + self.mantissa_bits
        exp_mask = (1 << self.exp_bits) - 1
        mantissa_mask = (1 << self.mantissa_bits) - 1
        e5m2_inf_codes = (0b01111100, 0b11111100)
        e5m2_nan_codes = (0b01111101, 0b11111101, 0b01111110, 0b11111110, 0b01111111, 0b11111111)
        e4m3_nan_codes = (0b01111111, 0b11111111)
        i2f = []
        for i in range(1 << length):
            sign = i >> (length - 1)
            biased_exponent = (i >> self.mantissa_bits) & exp_mask
            mantissa = i & mantissa_mask
            if biased_exponent == 0:
                # Subnormal: no implicit leading 1.
                significand = mantissa
                exponent = 1 - self.bias
            else:
                # Normal: implicit leading 1 above the stored mantissa bits.
                significand = (1 << self.mantissa_bits) | mantissa
                exponent = biased_exponent - self.bias
            f = float(significand) / (2.0 ** self.mantissa_bits)
            f *= 2 ** exponent
            if length == 8:
                # Some special cases
                if self.exp_bits == 5:
                    if i in e5m2_inf_codes:
                        f = float('inf')
                    if i in e5m2_nan_codes:
                        f = float('nan')
                if self.exp_bits == 4:
                    if i in e4m3_nan_codes:
                        f = float('nan')
            i2f.append(-f if sign else f)
        return array.array('f', i2f)

    def createLUT_for_float16_to_mxfp(self) -> bytes:
        """Create a LUT to convert a float16 (indexed by its 16-bit pattern) into a MXFP code.

        Used to create the LUT that was compressed and stored for the fp8 code. 8-bit
        formats round via the third-party ``gfloat`` package. 4- and 6-bit formats use
        ``slow_float_to_int``. ``lut_int_to_float`` is built first if it is missing.

        :raises ValueError: if the format is not 4, 6 or 8 bits long.
        """
        length = 1 + self.exp_bits + self.mantissa_bits
        if length not in (4, 6, 8):
            raise ValueError(f"Unsupported MXFP bit length {length}; expected 4, 6 or 8.")
        if self.lut_int_to_float is None:
            # Both conversion paths below read the int -> float LUT.
            self.lut_int_to_float = self.createLUT_for_int_to_float()

        if length == 8:
            import gfloat
            from gfloat.formats import format_info_ocp_e5m2, format_info_ocp_e4m3
            fi = format_info_ocp_e5m2 if self.exp_bits == 5 else format_info_ocp_e4m3
            saturate = self.mxfp_overflow == 'saturate'

            def convert(f: float) -> int:
                fp = gfloat.round_float(fi, f, sat=saturate)
                if math.isnan(fp):
                    return 0b11111111
                # Special case for negative zero: index() would find +0.0 (code 0) first.
                if fp == 0.0 and math.copysign(1.0, fp) == -1.0:
                    return 0b10000000
                return self.lut_int_to_float.index(fp)
        else:
            convert = self.slow_float_to_int

        lut = bytearray(1 << 16)
        for i in range(1 << 16):
            # Reinterpret the 16-bit pattern i as a float16 value.
            f, = struct.unpack('>e', struct.pack('>H', i))
            lut[i] = convert(f)
        return bytes(lut)

188 

189 

# 4- and 6-bit formats. Their LUTs contain no infinity or NaN entries, and values out
# of range clamp to the largest-magnitude code of the matching sign.
e2m1mxfp_fmt = MXFPFormat(exp_bits=2, mantissa_bits=1, bias=1,
                          mxfp_overflow='saturate')
e2m3mxfp_fmt = MXFPFormat(exp_bits=2, mantissa_bits=3, bias=1,
                          mxfp_overflow='saturate')
e3m2mxfp_fmt = MXFPFormat(exp_bits=3, mantissa_bits=2, bias=3,
                          mxfp_overflow='saturate')

# 8-bit formats, saturating: out-of-range values clamp to the largest finite value.
e4m3mxfp_saturate_fmt = MXFPFormat(exp_bits=4, mantissa_bits=3, bias=7,
                                   mxfp_overflow='saturate')
e5m2mxfp_saturate_fmt = MXFPFormat(exp_bits=5, mantissa_bits=2, bias=15,
                                   mxfp_overflow='saturate')

# 8-bit formats, overflowing: the clamp code is NaN for e4m3 and +/-inf for e5m2.
e4m3mxfp_overflow_fmt = MXFPFormat(exp_bits=4, mantissa_bits=3, bias=7,
                                   mxfp_overflow='overflow')
e5m2mxfp_overflow_fmt = MXFPFormat(exp_bits=5, mantissa_bits=2, bias=15,
                                   mxfp_overflow='overflow')

197 

198 

def decompress_luts():
    """Populate the lookup tables of every predefined MXFP format from compressed data."""
    predefined_formats = (
        e2m1mxfp_fmt,
        e2m3mxfp_fmt,
        e3m2mxfp_fmt,
        e4m3mxfp_saturate_fmt,
        e5m2mxfp_saturate_fmt,
        e4m3mxfp_overflow_fmt,
        e5m2mxfp_overflow_fmt,
    )
    for fmt in predefined_formats:
        fmt.decompress_luts()