Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.9/dist-packages/bitstring/fp8.py: 41%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

66 statements  

1""" 

2The 8-bit float formats used here are from a proposal supported by Graphcore, AMD and Qualcomm. 

3See https://arxiv.org/abs/2206.02915 

4 

5""" 

6 

7import struct 

8import zlib 

9import array 

10import bitarray 

11from bitstring.luts import binary8_luts_compressed 

12import math 

13 

14 

class Binary8Format:
    """8-bit floating point formats based on draft IEEE binary8.

    A code is laid out, most significant bit first, as one sign bit,
    ``exp_bits`` exponent bits and ``7 - exp_bits`` significand bits.
    Conversions are table driven: ``lut_binary8_to_float`` maps a code
    (0-255) to a float and ``lut_float16_to_binary8`` maps a float16 bit
    pattern (0-65535) to a code. Neither table exists until
    ``decompress_luts`` or ``create_luts`` has been called.
    """

    def __init__(self, exp_bits: int, bias: int):
        """Store the format parameters; the lookup tables are not built here."""
        self.exp_bits = exp_bits
        self.bias = bias
        # Codes returned by float_to_int8 when the value cannot be packed as a
        # float16. createLUT_for_binary8_to_float maps them to +inf and -inf.
        self.pos_clamp_value = 0b01111111
        self.neg_clamp_value = 0b11111111

    def __str__(self):
        return f"Binary8Format(exp_bits={self.exp_bits}, bias={self.bias})"

    def decompress_luts(self):
        """Load both lookup tables from the precomputed, zlib-compressed data.

        The entry is looked up in ``binary8_luts_compressed`` by the key
        ``(exp_bits, bias)``; a missing key raises ``KeyError``.
        """
        binary8_to_float_compressed, float16_to_binary8_compressed = binary8_luts_compressed[(self.exp_bits, self.bias)]
        self.lut_float16_to_binary8 = zlib.decompress(float16_to_binary8_compressed)
        dec = zlib.decompress(binary8_to_float_compressed)
        # The stored table is a run of little-endian 32-bit floats.
        self.lut_binary8_to_float = struct.unpack(f'<{len(dec) // 4}f', dec)

    def create_luts(self):
        """Build both lookup tables from scratch instead of decompressing them.

        The binary8 -> float table is built first because the float16 ->
        binary8 builder searches it.
        """
        self.lut_binary8_to_float = self.createLUT_for_binary8_to_float()
        self.lut_float16_to_binary8 = self.createLUT_for_float16_to_binary8()

    def float_to_int8(self, f: float) -> int:
        """Given a Python float convert to the best float8 (expressed as an integer in 0-255 range)."""
        # First convert the float to a float16, then a 16 bit uint
        try:
            b = struct.pack('>e', f)
        except (OverflowError, struct.error):
            # Too large for float16: return the positive or negative clamp code
            # (the codes the binary8 -> float table maps to +inf / -inf).
            return self.pos_clamp_value if f > 0 else self.neg_clamp_value
        f16_int = int.from_bytes(b, byteorder='big')
        # Then use this as an index to our large LUT
        return self.lut_float16_to_binary8[f16_int]

    def createLUT_for_float16_to_binary8(self) -> bytes:
        """Create the 65536-entry LUT mapping each float16 bit pattern to a float8 code.

        Used to create the LUT that was compressed and stored for the fp8 code.
        Needs the optional ``gfloat`` package and requires
        ``self.lut_binary8_to_float`` to be populated already.
        """
        import gfloat
        fi = gfloat.formats.format_info_p3109(8 - self.exp_bits)
        fp16_to_fp8 = bytearray(1 << 16)
        for i in range(1 << 16):
            # Reinterpret the 16-bit integer as a big-endian float16.
            b = struct.pack('>H', i)
            f, = struct.unpack('>e', b)
            fp = gfloat.round_float(fi, f)
            if math.isnan(fp):
                # NaN never compares equal, so it cannot be found with index().
                fp8_i = 0b10000000
            else:
                fp8_i = self.lut_binary8_to_float.index(fp)
            fp16_to_fp8[i] = fp8_i
        return bytes(fp16_to_fp8)

    def createLUT_for_binary8_to_float(self):
        """Create a LUT to convert an int in range 0-255 representing a float8 into a Python float.

        Returns an ``array.array('f')`` of 256 values. The bit fields are
        decoded with integer shifts and masks, so nothing outside the standard
        library is needed. Assumes ``1 <= exp_bits <= 7``.
        """
        frac_bits = 7 - self.exp_bits
        exp_mask = (1 << self.exp_bits) - 1
        frac_mask = (1 << frac_bits) - 1
        i2f = []
        for i in range(256):
            negative = i >> 7
            exponent = (i >> frac_bits) & exp_mask
            significand = i & frac_mask
            if exponent == 0:
                # Subnormal: no implicit leading one, fixed minimum exponent.
                exponent = 1 - self.bias
            else:
                # Normal: restore the implicit leading one.
                significand |= 1 << frac_bits
                exponent -= self.bias
            # value = significand / 2**frac_bits * 2**exponent
            f = math.ldexp(float(significand), exponent - frac_bits)
            i2f.append(-f if negative else f)
        # The code that would otherwise be minus zero represents NaN
        i2f[0b10000000] = float('nan')
        # and the all-ones exponent/significand codes are plus and minus infinity
        i2f[0b01111111] = float('inf')
        i2f[0b11111111] = float('-inf')
        return array.array('f', i2f)

88 

89 

# Shared module-level instances of the two supported layouts, written as
# sign.exponent.significand bit counts: 1.5.2 (p3binary) and 1.4.3 (p4binary).
p3binary_fmt = Binary8Format(exp_bits=5, bias=16)
p4binary_fmt = Binary8Format(exp_bits=4, bias=8)

93 

94 

def decompress_luts():
    """Decompress the stored lookup tables of both module-level binary8 formats."""
    for fmt in (p4binary_fmt, p3binary_fmt):
        fmt.decompress_luts()

97 p3binary_fmt.decompress_luts()