
Commit 2948801

Support MXFP6 packing and fused unpack-dequantise kernel (conflicts resolved) (#1810)
* Added MXFP6 packing for inference with tests
* Updated MXFP6 packing to use MXLinearConfig, amended pytests
* Added simpler PyTorch implementation of FP6 packing, for reference purposes
* FP6 packing on by default for inference, no training support
* Fixed test failure where Triton unavailable, ran linter
* Removed references to torch custom op API in the case where torch version < 2.4
1 parent bc4f51d commit 2948801

File tree

8 files changed (+815, -53 lines)
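
The commit message above mentions a simpler PyTorch reference implementation of FP6 packing. As a rough illustration of the idea only (not the code in this commit), four 6-bit values can be packed into three uint8 bytes; the helper name pack_fp6_reference and the exact bit ordering below are assumptions:

import torch

def pack_fp6_reference(x: torch.Tensor) -> torch.Tensor:
    # x: uint8 tensor whose last dim is a multiple of 4; each element holds one
    # FP6 encoding in its low 6 bits. Returns 3 bytes per 4 input values.
    assert x.dtype == torch.uint8 and x.shape[-1] % 4 == 0
    x = x.reshape(*x.shape[:-1], -1, 4)
    v0, v1, v2, v3 = x[..., 0], x[..., 1], x[..., 2], x[..., 3]
    b0 = (v0 << 2) | (v1 >> 4)           # all 6 bits of v0 + top 2 bits of v1
    b1 = ((v1 & 0x0F) << 4) | (v2 >> 2)  # low 4 bits of v1 + top 4 bits of v2
    b2 = ((v2 & 0x03) << 6) | v3         # low 2 bits of v2 + all 6 bits of v3
    packed = torch.stack([b0, b1, b2], dim=-1)
    return packed.reshape(*x.shape[:-2], -1)

This matches the 4-to-3 size ratio asserted by the new tests below (packed.numel() == 3 * orig.numel() // 4); the actual pack_uint6 helper and the fused Triton unpack/dequantize kernels in this commit may use a different byte layout.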

test/prototype/mx_formats/test_custom_cast.py

Lines changed: 51 additions & 7 deletions
@@ -8,7 +8,6 @@
 import torch
 from torch.utils._triton import has_triton
 
-import torchao.prototype.mx_formats.config as config
 from torchao.prototype.mx_formats.constants import (
     DTYPE_FP4,
     DTYPE_FP6_E2M3,
@@ -26,7 +25,10 @@
     f32_to_f6_e3m2_unpacked,
     get_bits,
     pack_uint4,
+    pack_uint6,
     triton_f4_to_bf16,
+    triton_f6_e2m3_to_bf16,
+    triton_f6_e3m2_to_bf16,
     unpack_uint4,
 )
 from torchao.prototype.mx_formats.fp_format_spec import (
@@ -329,12 +331,16 @@ def test_fp4_triton_unscaled_cast():
 def test_fp4_triton_scaled_cast():
     size = (256,)
     orig_vals = torch.randn(size, dtype=torch.float, device="cuda") * 100
-    mxtensor = MXTensor.to_mx(orig_vals, block_size=32, elem_dtype=DTYPE_FP4)
-
-    f32_ref = mxtensor.to_dtype(torch.float)
-    config.use_fp4_custom_triton_dequant_kernel = True
-    f32_triton = mxtensor.to_dtype(torch.float)
-    config.use_fp4_custom_triton_dequant_kernel = False
+    mxtensor_ref = MXTensor.to_mx(orig_vals, block_size=32, elem_dtype=DTYPE_FP4)
+    mxtensor_triton = MXTensor.to_mx(
+        orig_vals,
+        block_size=32,
+        elem_dtype=DTYPE_FP4,
+        use_fp4_custom_triton_dequant_kernel=True,
+    )
+
+    f32_ref = mxtensor_ref.to_dtype(torch.float)
+    f32_triton = mxtensor_triton.to_dtype(torch.float)
     assert torch.all(torch.eq(f32_ref, f32_triton))
 
 
@@ -411,3 +417,41 @@ def test_fp6_e3m2_rounding(f32_val, f6_e3m2_enc, device):
 
     f6_e3m2_unpacked = f32_to_f6_e3m2_unpacked(torch.tensor(-f32_val, device=device))
     assert f6_e3m2_unpacked.item() == (f6_e3m2_enc | 0b100000)
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
+@pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_4, reason="requires PyTorch >= 2.4")
+@pytest.mark.skipif(
+    is_sm_at_least_100(), reason="triton does not work yet on CUDA capability 10.0"
+)
+def test_fp6_e2m3_pack_unpack():
+    orig_vals = torch.Tensor([[0.0, 0.5, 7.5, -0.0], [-0.875, 1.0, -6.0, 0.125]]).to(
+        "cuda"
+    )
+    orig_vals_f6_unpacked = f32_to_f6_e2m3_unpacked(orig_vals)
+    orig_vals_f6_packed = pack_uint6(orig_vals_f6_unpacked)
+    assert orig_vals_f6_packed.numel() == (3 * orig_vals.numel() // 4)
+    orig_vals_f6_packed_unpacked = triton_f6_e2m3_to_bf16(orig_vals_f6_packed).to(
+        torch.float32
+    )
+    assert torch.all(orig_vals_f6_packed_unpacked == orig_vals)
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
+@pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_4, reason="requires PyTorch >= 2.4")
+@pytest.mark.skipif(
+    is_sm_at_least_100(), reason="triton does not work yet on CUDA capability 10.0"
+)
+def test_fp6_e3m2_pack_unpack():
+    orig_vals = torch.Tensor([[0.0, 5.0, 28.0, -0.0], [-0.25, 0.1875, 0.0625, 8.0]]).to(
+        "cuda"
+    )
+    orig_vals_f6_unpacked = f32_to_f6_e3m2_unpacked(orig_vals)
+    orig_vals_f6_packed = pack_uint6(orig_vals_f6_unpacked)
+    assert orig_vals_f6_packed.numel() == (3 * orig_vals.numel() // 4)
+    orig_vals_f6_packed_unpacked = triton_f6_e3m2_to_bf16(orig_vals_f6_packed).to(
+        torch.float32
+    )
+    assert torch.all(orig_vals_f6_packed_unpacked == orig_vals)

test/prototype/mx_formats/test_mx_linear.py

Lines changed: 11 additions & 11 deletions
@@ -59,14 +59,14 @@ def test_linear_eager(elem_dtype, bias, input_shape):
     """
     # elem_dtype is a tuple of (input, weight, gradient) dtypes.
     grad_shape = list(input_shape)
-    grad_shape[-1] = 6
+    grad_shape[-1] = 8
 
     m = nn.Sequential(
-        nn.Linear(8, 6, bias=bias, device="cuda"),
+        nn.Linear(8, 8, bias=bias, device="cuda"),
     )
     m_mx = copy.deepcopy(m)
     config = MXLinearConfig(
-        block_size=2,
+        block_size=4,
         elem_dtype=elem_dtype[0],
         elem_dtype_weight_override=elem_dtype[1],
         elem_dtype_grad_output_override=elem_dtype[2],
@@ -151,14 +151,14 @@ def test_linear_eager_emulated_vs_real_gemm(recipe_name, mkn):
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 def test_activation_checkpointing():
     input_shape = (2, 4)
-    grad_shape = (2, 6)
+    grad_shape = (2, 8)
     elem_dtype = torch.float8_e4m3fn
 
     m = nn.Sequential(
-        nn.Linear(4, 6, bias=True, device="cuda"),
-        nn.Linear(6, 6, bias=True, device="cuda"),
+        nn.Linear(4, 8, bias=True, device="cuda"),
+        nn.Linear(8, 8, bias=True, device="cuda"),
     )
-    config = MXLinearConfig(block_size=2, elem_dtype=elem_dtype)
+    config = MXLinearConfig(block_size=4, elem_dtype=elem_dtype)
     swap_linear_with_mx_linear(m, config=config)
 
     x = torch.randn(*input_shape, device="cuda").requires_grad_()
@@ -240,10 +240,10 @@ def test_inference_linear(elem_dtype, bias, input_shape):
     """
     Smoke test for inference linear module with mx weight
     """
-    m = nn.Sequential(nn.Linear(4, 6, bias=bias, dtype=torch.bfloat16))
+    m = nn.Sequential(nn.Linear(4, 8, bias=bias, dtype=torch.bfloat16))
     m = m.cuda()
     m_mx = copy.deepcopy(m)
-    config = MXLinearConfig(block_size=2, elem_dtype=elem_dtype)
+    config = MXLinearConfig(block_size=4, elem_dtype=elem_dtype)
     swap_linear_with_mx_inference_linear(m_mx, config=config)
 
     x = torch.randn(*input_shape, device="cuda", dtype=torch.bfloat16)
@@ -268,10 +268,10 @@ def test_inference_compile_simple(elem_dtype):
     if elem_dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
         if not is_sm_at_least_89():
             pytest.skip("CUDA capability >= 8.9 required for float8 in triton")
-    m = nn.Sequential(nn.Linear(4, 6, bias=False, dtype=torch.bfloat16))
+    m = nn.Sequential(nn.Linear(4, 8, bias=False, dtype=torch.bfloat16))
     m = m.cuda()
     m_mx = copy.deepcopy(m)
-    config = MXLinearConfig(block_size=2, elem_dtype=elem_dtype)
+    config = MXLinearConfig(block_size=4, elem_dtype=elem_dtype)
     swap_linear_with_mx_inference_linear(m_mx, config=config)
     m_mx = torch.compile(m_mx, fullgraph="true")
 

test/prototype/mx_formats/test_mx_tensor.py

Lines changed: 60 additions & 25 deletions
@@ -16,7 +16,7 @@
     DTYPE_FP6_E3M2,
     SUPPORTED_ELEM_DTYPES,
 )
-from torchao.prototype.mx_formats.custom_cast import pack_uint4
+from torchao.prototype.mx_formats.custom_cast import pack_uint4, pack_uint6
 from torchao.prototype.mx_formats.mx_tensor import (
     E8M0_EXPONENT_NAN_VAL,
     MXTensor,
@@ -75,7 +75,7 @@ def assert_sqnr_gt_threshold(orig, new, threshold):
 @pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES)
 def test_hello_world(elem_dtype):
     data = torch.randn(4, 4, device="cuda", dtype=torch.bfloat16)
-    block_size = 2
+    block_size = 4
     _test_mx(data, elem_dtype, block_size)
 
 
@@ -92,7 +92,7 @@ def test_realistic_numerics(elem_dtype, scale_calculation_mode):
 @pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES)
 def test_all_zeros(elem_dtype):
     data = torch.zeros(4, 4, device="cuda", dtype=torch.bfloat16)
-    block_size = 2
+    block_size = 4
     _test_mx(data, elem_dtype, block_size)
 
 
@@ -102,7 +102,7 @@ def test_some_zeros(elem_dtype):
     data = torch.randn(4, 4, device="cuda", dtype=torch.bfloat16)
     data[0, :] = 0.0
     data[:, 2] = 0.0
-    block_size = 2
+    block_size = 4
     _test_mx(data, elem_dtype, block_size)
 
 
@@ -114,33 +114,46 @@ def test_exponent_nan_in(elem_dtype):
     value is set to is NaN
     """
     tensor_hp = torch.tensor(
-        [float("nan"), 1, 2, 3, 4, 5], device="cuda", dtype=torch.bfloat16
+        [float("nan"), 1, 2, 3, 4, 5, 6, 7], device="cuda", dtype=torch.bfloat16
     )
-    block_size = 2
+    block_size = 4
     tensor_mx = MXTensor.to_mx(tensor_hp, elem_dtype, block_size)
     assert torch.all(tensor_mx._scale_e8m0[0] == E8M0_EXPONENT_NAN_VAL)
     assert not torch.any(tensor_mx._scale_e8m0[1:] == E8M0_EXPONENT_NAN_VAL)
 
 
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 @pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES)
-def test_exponent_nan_out(elem_dtype):
+@pytest.mark.parametrize("pack_fp6", [False, True])
+def test_exponent_nan_out(elem_dtype, pack_fp6):
     """
     If block exponent value is NaN, the MX tensor block value is NaN
     """
     scale_e8m0_bits = torch.tensor(
-        [E8M0_EXPONENT_NAN_VAL, 23, 42], dtype=torch.uint8, device="cuda"
+        [E8M0_EXPONENT_NAN_VAL, 23], dtype=torch.uint8, device="cuda"
     )
+
+    block_size = 4
+
     if elem_dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
-        data_bits = torch.tensor([0, 1, 2, 3, 4, 5], dtype=elem_dtype, device="cuda")  # noqa: E501
+        data_bits = torch.tensor(
+            [0, 1, 2, 3, 4, 5, 6, 7], dtype=elem_dtype, device="cuda"
+        )  # noqa: E501
     elif elem_dtype in (DTYPE_FP6_E2M3, DTYPE_FP6_E3M2):
-        data_bits = torch.tensor([0, 1, 2, 3, 4, 5], dtype=torch.uint8, device="cuda")  # noqa: E501
+        data_bits = torch.tensor(
+            [0, 1, 2, 3, 4, 5, 6, 7], dtype=torch.uint8, device="cuda"
+        )  # noqa: E501
+        if pack_fp6:
+            data_bits = data_bits.reshape(-1, block_size)
+            data_bits = pack_uint6(data_bits)
     elif elem_dtype == DTYPE_FP4:
-        data_bits = torch.tensor([0, 1, 2, 3, 4, 5], dtype=torch.uint8, device="cuda")  # noqa: E501
+        data_bits = torch.tensor(
+            [0, 1, 2, 3, 4, 5, 6, 7], dtype=torch.uint8, device="cuda"
+        )  # noqa: E501
         data_bits = pack_uint4(data_bits)
     else:
         raise AssertionError("unsupported")
-    block_size = 2
+    block_size = 4
     use_fp4_custom_triton_dequant_kernel = False
     tensor_mx = MXTensor(
         scale_e8m0_bits,
@@ -150,10 +163,11 @@ def test_exponent_nan_out(elem_dtype):
         torch.float,
         use_fp4_custom_triton_dequant_kernel,
         MXGemmKernelChoice.EMULATED,
+        pack_fp6,
     )
     tensor_hp = tensor_mx.to_dtype(torch.float)
-    assert torch.all(torch.isnan(tensor_hp[0:1]))
-    assert not torch.any(torch.isnan(tensor_hp[2:]))
+    assert torch.all(torch.isnan(tensor_hp.flatten()[0:4]))
+    assert not torch.any(torch.isnan(tensor_hp.flatten()[4:]))
 
 
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@@ -162,24 +176,26 @@ def test_ranks(elem_dtype):
     """
     The reshaping logic works for various ranks
    """
-    B = 2
-    shapes = ((B * 4,), (B * 4, 2), (B * 4, 2, 2), (B * 4, 2, 2, 2))
+    B = 4
+    shapes = ((B * 4,), (B * 4, 4), (B * 4, 4, 4), (B * 4, 4, 4, 4))
     for s in shapes:
         tensor_hp = torch.randn(*s, device="cuda", dtype=torch.bfloat16)
         _test_mx(tensor_hp, elem_dtype, B)
 
 
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 @pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES)
-def test_block_sizes(elem_dtype):
+@pytest.mark.parametrize("B", [1, 4, 32])
+def test_block_sizes(elem_dtype, B):
     """
     Smoke test for various block sizes
     """
-    for B in (1, 2, 32):
-        if B == 1 and elem_dtype == DTYPE_FP4:
-            pytest.skip("unsupported configuration")
-        tensor_hp = torch.randn(B, device="cuda", dtype=torch.bfloat16)
-        _test_mx(tensor_hp, elem_dtype, B)
+    if B == 1 and elem_dtype == DTYPE_FP4:
+        pytest.skip("unsupported configuration")
+    elif B % 4 != 0 and elem_dtype in [DTYPE_FP6_E2M3, DTYPE_FP6_E3M2]:
+        pytest.skip("unsupported configuration")
+    tensor_hp = torch.randn(B, device="cuda", dtype=torch.bfloat16)
+    _test_mx(tensor_hp, elem_dtype, B)
 
 
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@@ -224,14 +240,30 @@ def test_cast_autograd(elem_dtype):
     torch.testing.assert_close(grad, x.grad, atol=0, rtol=0)
 
 
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 @pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES)
 def test_view(elem_dtype):
-    x = torch.randn(1, 2, 4)
-    block_size = 2
+    x = torch.randn(1, 2, 4, device="cuda")
+    block_size = 4
     x_mx = MXTensor.to_mx(x, elem_dtype, block_size)
     x_mx_2 = x_mx.view(2, 4)  # noqa: F841
 
 
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.parametrize("elem_dtype", [DTYPE_FP6_E2M3, DTYPE_FP6_E3M2])
+@pytest.mark.parametrize("pack_fp6", [False, True])
+def test_fp6_packing(elem_dtype, pack_fp6):
+    x = torch.randn(1, 2, 4, device="cuda")
+    block_size = 4
+    x_mx = MXTensor.to_mx(x, elem_dtype, block_size, pack_fp6=pack_fp6)
+    if pack_fp6:
+        expected_packed_shape = torch.Size([*x.shape[:-1], 3 * x.shape[-1] // 4])
+    else:
+        expected_packed_shape = x.shape
+
+    assert x_mx._data.shape == expected_packed_shape
+
+
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 @pytest.mark.skipif(
     is_sm_at_least_100(), reason="triton does not work yet on CUDA capability 10.0"
@@ -253,7 +285,7 @@ def test_to_mx_from_mx_compile_numerics(elem_dtype, hp_dtype, all_zeros):
         x = torch.randn(*shape, dtype=hp_dtype, device="cuda")
     else:
        x = torch.zeros(*shape, dtype=hp_dtype, device="cuda")
-    block_size = 2
+    block_size = 4
     to_mx_c = torch.compile(MXTensor.to_mx, fullgraph=True)
 
     x_mx = MXTensor.to_mx(x, elem_dtype, block_size)
@@ -269,13 +301,15 @@ def test_to_mx_from_mx_compile_numerics(elem_dtype, hp_dtype, all_zeros):
     to_dtype_c = torch.compile(to_dtype, fullgraph=True)
 
     use_fp4_custom_triton_dequant_kernel = False
+    pack_fp6 = False
     x_mx_dq = to_dtype(
         x_mx._data,
         x_mx._scale_e8m0,
         x_mx._elem_dtype,
         x_mx._block_size,
         hp_dtype,  # noqa: E501
         use_fp4_custom_triton_dequant_kernel,
+        pack_fp6,
     )
     x_mx_c_dq = to_dtype_c(
         x_mx_c._data,
@@ -284,6 +318,7 @@ def test_to_mx_from_mx_compile_numerics(elem_dtype, hp_dtype, all_zeros):
         x_mx_c._block_size,
         hp_dtype,
         use_fp4_custom_triton_dequant_kernel,
+        pack_fp6,
     )
     torch.testing.assert_close(x_mx_dq, x_mx_c_dq, atol=0, rtol=0)
 
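
As a quick check of the shape arithmetic in the new test_fp6_packing above: packing only shrinks the last dimension by a factor of 3/4, so the (1, 2, 4) input used in the test yields a packed _data shape of (1, 2, 3).

x_shape = (1, 2, 4)
packed_shape = (*x_shape[:-1], 3 * x_shape[-1] // 4)  # 4 six-bit values -> 3 bytes
assert packed_shape == (1, 2, 3)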

torchao/prototype/mx_formats/config.py

Lines changed: 4 additions & 0 deletions
@@ -61,6 +61,10 @@ class MXLinearConfig:
     # If True, uses a custom triton kernel for fp4 dequantize
     use_fp4_custom_triton_dequant_kernel: bool = False
 
+    # If True, packs 4xFP6 into 3xuint8 containers for inference, using custom triton
+    # kernels (fused unpack/dequantize). Training not currently supported.
+    pack_fp6 = True if hasattr(torch.library, "custom_op") else False
+
     def __post_init__(self):
         # validate elem_dtype and its overrides
         assert (
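
For context, a minimal inference usage sketch consistent with the tests above; MXLinearConfig, DTYPE_FP6_E2M3, and swap_linear_with_mx_inference_linear all appear in this diff, but the mx_linear import path below is an assumption:

import copy

import torch
import torch.nn as nn

from torchao.prototype.mx_formats.config import MXLinearConfig
from torchao.prototype.mx_formats.constants import DTYPE_FP6_E2M3
from torchao.prototype.mx_formats.mx_linear import (  # assumed import path
    swap_linear_with_mx_inference_linear,
)

m = nn.Sequential(nn.Linear(4, 8, bias=False, dtype=torch.bfloat16)).cuda()
m_mx = copy.deepcopy(m)

# pack_fp6 defaults to True when torch.library.custom_op exists (PyTorch >= 2.4),
# so FP6 weights are stored packed and dequantized by the fused triton kernel.
config = MXLinearConfig(block_size=4, elem_dtype=DTYPE_FP6_E2M3)
swap_linear_with_mx_inference_linear(m_mx, config=config)

x = torch.randn(2, 4, device="cuda", dtype=torch.bfloat16)
y = m_mx(x)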
