Skip to content

Commit c147495

Browse files
committed
add nvfp4 packing
1 parent be30822 commit c147495

File tree

1 file changed

+60
-0
lines changed

1 file changed

+60
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
import torch
2+
import numpy
3+
4+
x = torch.Tensor(
5+
[[-0.5000, -6.0000, -0.5000, -1.5000, -1.0000, 6.0000, -0.0000, -0.0000],
6+
[-1.0000, -6.0000, -0.5000, -0.0000, 0.5000, 0.5000, -0.0000, -0.0000],
7+
[-3.0000, -6.0000, -0.5000, -2.0000, -0.5000, -1.5000, -0.0000, -0.0000],
8+
[ 1.5000, 6.0000, -0.0000, -0.5000, 1.0000, 1.0000, -0.0000, 0.0000]]
9+
)
10+
11+
m, n = x.shape
12+
13+
FLOAT_TO_E2M1 = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0, 0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0]
14+
conversion_dict = {}
15+
16+
# Dictionary between fp4 value and index
17+
for i in range(len(FLOAT_TO_E2M1)):
18+
conversion_dict[FLOAT_TO_E2M1[i]] = i
19+
20+
21+
x_numpy = x.to("cpu").numpy()
22+
x_index = numpy.array([[conversion_dict[i] for i in row] for row in x_numpy], dtype=numpy.uint8)
23+
x_index_bits = numpy.unpackbits(x_index)
24+
25+
packed_shape = numpy.zeros([x_index_bits.shape[0] // 2], numpy.uint8)
26+
start = 0
27+
end = 16
28+
i = 0
29+
30+
# janky bit manipulation
31+
while end < len(x_index_bits):
32+
subset = x_index_bits[start:end]
33+
subset_a = subset[4:8]
34+
subset_b = subset[12:16]
35+
packed_shape[i+4:i+8] = subset_a
36+
packed_shape[i:i+4] = subset_b
37+
start = end
38+
end = start + 16
39+
i += 8
40+
41+
packed = numpy.packbits(packed_shape)
42+
packed = torch.Tensor(packed).to(torch.uint8)
43+
packed = packed.reshape(m, n // 2)
44+
45+
46+
# from vLLM
47+
def cast_from_fp4(x, m, n):
48+
# The fp4 values are packed in uint8 as [v_1st | v_2nd]
49+
v_2nd = x & 0xF
50+
v_1st = (x >> 4) & 0xF
51+
c = torch.stack((v_2nd, v_1st), dim=-1)
52+
out = torch.tensor([FLOAT_TO_E2M1[x] for x in c.flatten()])
53+
out = out.reshape(m, n).to(torch.float32)
54+
return out
55+
56+
57+
out = cast_from_fp4(packed, m, n)
58+
print(out.shape, packed.shape)
59+
print(out)
60+
assert torch.equal(out, x)

0 commit comments

Comments
 (0)