|
| 1 | +import torch |
| 2 | +import numpy |
| 3 | + |
# Sample 4x8 tensor whose entries are all exactly representable in FP4 E2M1
# (1 sign bit, 2 exponent bits, 1 mantissa bit).
x = torch.Tensor(
    [[-0.5000, -6.0000, -0.5000, -1.5000, -1.0000, 6.0000, -0.0000, -0.0000],
     [-1.0000, -6.0000, -0.5000, -0.0000, 0.5000, 0.5000, -0.0000, -0.0000],
     [-3.0000, -6.0000, -0.5000, -2.0000, -0.5000, -1.5000, -0.0000, -0.0000],
     [ 1.5000, 6.0000, -0.0000, -0.5000, 1.0000, 1.0000, -0.0000, 0.0000]]
)

m, n = x.shape

# The 16 E2M1 code points in code order: codes 0-7 are the positive values,
# codes 8-15 the negative ones (code 8 is -0.0).
FLOAT_TO_E2M1 = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0, 0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0]

# Inverse map: float value -> 4-bit code. NOTE: 0.0 and -0.0 hash/compare
# equal as dict keys, so the later entry wins and both encode to code 8;
# they decode back to 0.0, and -0.0 == 0.0, so the round trip still holds.
conversion_dict = {value: code for code, value in enumerate(FLOAT_TO_E2M1)}

x_numpy = x.to("cpu").numpy()
x_index = numpy.array([[conversion_dict[v] for v in row] for row in x_numpy], dtype=numpy.uint8)

# Pack each adjacent pair of 4-bit codes into one byte: the first value of the
# pair goes in the LOW nibble, the second in the HIGH nibble (matching how
# cast_from_fp4 unpacks below).
#
# BUGFIX: the previous unpackbits/while-loop implementation used
# `while end < len(x_index_bits)`, which skipped the final 16-bit pair and
# left the last packed byte 0. It was masked here only because the last pair
# is (-0.0, 0.0), which decodes to (0.0, 0.0) and compares equal. Packing the
# nibbles directly is both simpler and correct for the last pair.
packed_numpy = (x_index[:, 1::2] << 4) | x_index[:, 0::2]
packed = torch.from_numpy(packed_numpy).to(torch.uint8).reshape(m, n // 2)
| 44 | + |
| 45 | + |
# Adapted from vLLM.
def cast_from_fp4(x, m, n,
                  e2m1=(0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
                        0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0)):
    """Decode nibble-packed FP4 (E2M1) codes into a float32 tensor.

    Each byte of ``x`` holds two 4-bit codes packed as [v_1st | v_2nd]:
    the LOW nibble is the logically-first value of the pair, the HIGH
    nibble the second.

    Args:
        x: integer tensor of packed bytes (m rows, n // 2 bytes per row).
        m: number of rows of the decoded output.
        n: number of columns of the decoded output (twice the packed width).
        e2m1: lookup table mapping the 16 E2M1 codes to floats. Defaults to
            the standard E2M1 value table, so the function no longer depends
            on the module-level FLOAT_TO_E2M1 global.

    Returns:
        (m, n) float32 tensor of decoded values.
    """
    low = x & 0xF            # first value of each pair
    high = (x >> 4) & 0xF    # second value of each pair
    # Interleave (low, high) per byte, then decode all codes in one
    # vectorized table lookup instead of a per-element Python loop.
    codes = torch.stack((low, high), dim=-1).reshape(-1)
    lut = torch.tensor(e2m1, dtype=torch.float32)
    return lut[codes.long()].reshape(m, n)
| 55 | + |
| 56 | + |
# Sanity check: decoding the packed representation must reproduce `x`.
# (-0.0 decodes to 0.0; torch.equal still passes since -0.0 == 0.0.)
out = cast_from_fp4(packed, m, n)
print(out.shape, packed.shape)
print(out)
assert torch.equal(out, x)