
Commit 4bfd7c0

Add mx_fp4 path (#2201)
stack-info: PR: #2201, branch: drisspg/stack/54
1 parent ec15554 commit 4bfd7c0
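The diff below replaces the prototype string constant `DTYPE_FP4` with PyTorch's native `torch.float4_e2m1fn_x2` dtype (available in PyTorch 2.8+) and threads an mxfp4 option through the inference config. As a rough sketch of the path this enables, following the updated `test_inference_subclass` test further down (the `MXFPInferenceConfig` import location is an assumption, since the diff does not show it):

```python
# Hedged sketch of the mx_fp4 inference path exercised by the updated tests in
# this commit; requires PyTorch 2.8+ for torch.float4_e2m1fn_x2.
import copy

import torch
import torch.nn as nn

from torchao.prototype.mx_formats.config import MXGemmKernelChoice
# NOTE: the MXFPInferenceConfig import path is not shown in this diff; assumed here.
from torchao.prototype.mx_formats import MXFPInferenceConfig
from torchao.quantization import quantize_

m = nn.Linear(32, 128, bias=True, dtype=torch.bfloat16, device="cuda")
m_mx = copy.deepcopy(m)

# fp4 routes to the CUTLASS mx gemm kernel; fp8 would use cuBLAS, per the test logic
config = MXFPInferenceConfig(
    activation_dtype=torch.float4_e2m1fn_x2,
    weight_dtype=torch.float4_e2m1fn_x2,
    gemm_kernel_choice=MXGemmKernelChoice.CUTLASS,
)
quantize_(m_mx, config=config)

x = torch.randn(8, 32, device="cuda", dtype=torch.bfloat16)
y_mx = m_mx(x)  # runs the quantized mxfp4 linear
```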

File tree

11 files changed (+93, -56 lines)


test/prototype/mx_formats/test_kernels.py

Lines changed: 4 additions & 3 deletions

@@ -9,7 +9,6 @@
 from torch.utils._triton import has_triton

 from torchao.prototype.mx_formats.constants import (
-    DTYPE_FP4,
     DTYPE_FP6_E2M3,
     DTYPE_FP6_E3M2,
     F4_E2M1_EXP_BIAS,
@@ -335,11 +334,13 @@ def test_fp4_triton_unscaled_cast():
 def test_fp4_triton_scaled_cast():
     size = (256,)
     orig_vals = torch.randn(size, dtype=torch.float, device="cuda") * 100
-    mxtensor_ref = MXTensor.to_mx(orig_vals, block_size=32, elem_dtype=DTYPE_FP4)
+    mxtensor_ref = MXTensor.to_mx(
+        orig_vals, block_size=32, elem_dtype=torch.float4_e2m1fn_x2
+    )
     mxtensor_triton = MXTensor.to_mx(
         orig_vals,
         block_size=32,
-        elem_dtype=DTYPE_FP4,
+        elem_dtype=torch.float4_e2m1fn_x2,
         use_fp4_custom_triton_dequant_kernel=True,
     )

test/prototype/mx_formats/test_mx_linear.py

Lines changed: 43 additions & 14 deletions

@@ -11,12 +11,12 @@
 import torch.nn as nn

 from torchao.prototype.mx_formats.config import (
+    MXGemmKernelChoice,
     MXInferenceLinearConfig,
     MXLinearConfig,
     MXLinearRecipeName,
 )
 from torchao.prototype.mx_formats.constants import (
-    DTYPE_FP4,
     DTYPE_FP6_E2M3,
     DTYPE_FP6_E3M2,
     SUPPORTED_ELEM_DTYPES,
@@ -29,15 +29,14 @@
 from torchao.quantization import quantize_
 from torchao.quantization.utils import compute_error
 from torchao.utils import (
-    TORCH_VERSION_AT_LEAST_2_7,
     TORCH_VERSION_AT_LEAST_2_8,
     is_sm_at_least_89,
     is_sm_at_least_100,
 )

 torch.manual_seed(2)

-if not TORCH_VERSION_AT_LEAST_2_7:
+if not TORCH_VERSION_AT_LEAST_2_8:
     pytest.skip("Unsupported PyTorch version", allow_module_level=True)

@@ -51,19 +50,28 @@ def run_around_tests():
     torch._dynamo.reset()


-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
-@pytest.mark.parametrize(
-    "elem_dtype",
-    (
+elem_dtypes = (
+    [
         # test each dtype
         (torch.float8_e4m3fn, torch.float8_e4m3fn, torch.float8_e4m3fn),
         (DTYPE_FP6_E3M2, DTYPE_FP6_E3M2, DTYPE_FP6_E3M2),
         (DTYPE_FP6_E2M3, DTYPE_FP6_E2M3, DTYPE_FP6_E2M3),
-        (DTYPE_FP4, DTYPE_FP4, DTYPE_FP4),
+        (torch.float4_e2m1fn_x2, torch.float4_e2m1fn_x2, torch.float4_e2m1fn_x2),
         # only test one type of mixed-dtype overrides, to save testing time
-        (torch.float8_e4m3fn, DTYPE_FP4, DTYPE_FP4),
-    ),
+        (torch.float8_e4m3fn, torch.float4_e2m1fn_x2, torch.float4_e2m1fn_x2),
+    ]
+    if TORCH_VERSION_AT_LEAST_2_8
+    else [
+        # test each dtype
+        (torch.float8_e4m3fn, torch.float8_e4m3fn, torch.float8_e4m3fn),
+        (DTYPE_FP6_E3M2, DTYPE_FP6_E3M2, DTYPE_FP6_E3M2),
+        (DTYPE_FP6_E2M3, DTYPE_FP6_E2M3, DTYPE_FP6_E2M3),
+    ]
 )
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.parametrize("elem_dtype", elem_dtypes)
 @pytest.mark.parametrize("bias", [True, False])
 @pytest.mark.parametrize("input_shape", [(128, 256), (1, 128, 256), (1, 1, 128, 256)])
 @pytest.mark.parametrize("use_fp8_dim1_cast_triton_kernel", [False, True])
@@ -155,7 +163,7 @@ def test_linear_eager_emulated_vs_real_gemm(recipe_name, mkn):

     elem_dtype = torch.float8_e4m3fn
     if recipe_name == MXLinearRecipeName.MXFP4_CUTLASS:
-        elem_dtype = DTYPE_FP4
+        elem_dtype = torch.float4_e2m1fn_x2

     config_emulated = MXLinearConfig(block_size=32, elem_dtype=elem_dtype)
     config_real = MXLinearConfig.from_recipe_name(recipe_name)
@@ -375,12 +383,21 @@ def test_inference_print_str():
     assert "kernel=emulated" in s


+test_dtypes = (
+    [torch.float8_e4m3fn, torch.float4_e2m1fn_x2]
+    if TORCH_VERSION_AT_LEAST_2_8
+    else [
+        torch.float8_e4m3fn,
+    ]
+)
+
+
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 @pytest.mark.skipif(
     not TORCH_VERSION_AT_LEAST_2_8, reason="torch.compile requires PyTorch 2.8+"
 )
 @pytest.mark.skipif(not is_sm_at_least_100, reason="Reqs sm100")
-@pytest.mark.parametrize("elem_dtype", [torch.float8_e4m3fn])
+@pytest.mark.parametrize("elem_dtype", [torch.float8_e4m3fn, torch.float4_e2m1fn_x2])
 @pytest.mark.parametrize("bias", [True, False])
 @pytest.mark.parametrize("compile", [True, False])
 @torch.no_grad()
@@ -394,7 +411,16 @@ def test_inference_subclass(elem_dtype, bias: bool, compile: bool):

     m = nn.Linear(32, 128, bias=bias, dtype=torch.bfloat16, device="cuda")
     m_mx = copy.deepcopy(m)
-    config = MXFPInferenceConfig()
+    kernel_choice = (
+        MXGemmKernelChoice.CUTLASS
+        if elem_dtype == torch.float4_e2m1fn_x2
+        else MXGemmKernelChoice.CUBLAS
+    )
+    config = MXFPInferenceConfig(
+        activation_dtype=elem_dtype,
+        weight_dtype=elem_dtype,
+        gemm_kernel_choice=kernel_choice,
+    )
     quantize_(m_mx, config=config)
     if compile:
         m_mx = torch.compile(m_mx, fullgraph=True)
@@ -403,4 +429,7 @@
     y_ref = m(x)
     y_mx = m_mx(x)
     sqnr = compute_error(y_ref, y_mx)
-    assert sqnr >= 25.0, f"Got a sqnr of {sqnr} for {elem_dtype} and bias={bias}"
+    SQNR_THRESHOLD = 25.0 if elem_dtype == torch.float8_e4m3fn else 15.0
+    assert sqnr >= SQNR_THRESHOLD, (
+        f"Got a sqnr of {sqnr} for {elem_dtype} and bias={bias}"
+    )

test/prototype/mx_formats/test_mx_mm.py

Lines changed: 10 additions & 5 deletions

@@ -10,11 +10,14 @@

 from torchao.float8.float8_utils import compute_error
 from torchao.ops import mx_fp4_bf16
-from torchao.prototype.mx_formats.mx_tensor import DTYPE_FP4, MXTensor
+from torchao.prototype.mx_formats.mx_tensor import MXTensor
 from torchao.prototype.mx_formats.utils import to_blocked
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_7, is_sm_at_least_100
+from torchao.utils import (
+    TORCH_VERSION_AT_LEAST_2_8,
+    is_sm_at_least_100,
+)

-if not TORCH_VERSION_AT_LEAST_2_7:
+if not TORCH_VERSION_AT_LEAST_2_8:
     pytest.skip("Unsupported PyTorch version", allow_module_level=True)

@@ -25,7 +28,7 @@ def run_matrix_test(M: int, K: int, N: int, format) -> float:
     a = torch.rand((M, K), dtype=dtype, device=device)
     b = torch.rand((N, K), dtype=dtype, device=device)

-    fmt = torch.float8_e4m3fn if format == "fp8" else DTYPE_FP4
+    fmt = torch.float8_e4m3fn if format == "fp8" else torch.float4_e2m1fn_x2
     mx_func = (
         partial(torch._scaled_mm, out_dtype=torch.bfloat16)
         if format == "fp8"
@@ -75,7 +78,9 @@ def run_matrix_test(M: int, K: int, N: int, format) -> float:
     ],
     ids=lambda x: f"{x[0]}x{x[1]}x{x[2]}",
 )
-@pytest.mark.parametrize("format", ["fp8", "fp4"])
+@pytest.mark.parametrize(
+    "format", ["fp8", "fp4"] if TORCH_VERSION_AT_LEAST_2_8 else ["fp8"]
+)
 def test_matrix_multiplication(size, format):
     M, K, N = size
     sqnr = run_matrix_test(M, K, N, format)

test/prototype/mx_formats/test_mx_tensor.py

Lines changed: 3 additions & 4 deletions

@@ -12,7 +12,6 @@

 from torchao.prototype.mx_formats.config import MXGemmKernelChoice
 from torchao.prototype.mx_formats.constants import (
-    DTYPE_FP4,
     DTYPE_FP6_E2M3,
     DTYPE_FP6_E3M2,
     SUPPORTED_ELEM_DTYPES,
@@ -363,7 +362,7 @@ def test_exponent_nan_out(elem_dtype, pack_fp6):
     if pack_fp6:
         data_bits = data_bits.reshape(-1, block_size)
         data_bits = pack_uint6(data_bits)
-    elif elem_dtype == DTYPE_FP4:
+    elif elem_dtype == torch.float4_e2m1fn_x2:
         data_bits = torch.tensor(
             [0, 1, 2, 3, 4, 5, 6, 7], dtype=torch.uint8, device="cuda"
         )  # noqa: E501
@@ -407,7 +406,7 @@ def test_block_sizes(elem_dtype, B):
     """
     Smoke test for various block sizes
     """
-    if B == 1 and elem_dtype == DTYPE_FP4:
+    if B == 1 and elem_dtype == torch.float4_e2m1fn_x2:
         pytest.skip("unsupported configuration")
     elif B % 4 != 0 and elem_dtype in [DTYPE_FP6_E2M3, DTYPE_FP6_E3M2]:
         pytest.skip("unsupported configuration")
@@ -422,7 +421,7 @@ def test_transpose(elem_dtype, fp4_triton):
     """
     Verify that transposing an MX tensor works
     """
-    if elem_dtype != DTYPE_FP4 and fp4_triton:
+    if elem_dtype != torch.float4_e2m1fn_x2 and fp4_triton:
         pytest.skip("unsupported configuration")

     M, K = 128, 256

torchao/prototype/mx_formats/README.md

Lines changed: 8 additions & 8 deletions

@@ -1,6 +1,6 @@
 # MX training and inference with native PyTorch

-This is a workflow for e2e training and inference with MX dtypes from the [MX OCP spec](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf)
+This is a workflow for e2e training and inference with MX dtypes from the [MX OCP spec](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf)
 in native PyTorch. We are currently in prototype and are actively working on optimizing these workflows on the NVIDIA B200 hardware.

 ## Overall status
@@ -34,8 +34,8 @@ gemm_kernel_choice = MXGemmKernelChoice.CUBLAS

 m = torch.nn.Sequential(torch.nn.Linear(32, 32)).cuda()
 config = MXLinearConfig(
-    elem_dtype=torch.float8_e4m3fn,
-    block_size=32,
+    elem_dtype=torch.float8_e4m3fn,
+    block_size=32,
     gemm_kernel_choice=gemm_kernel_choice,
 )
 quantize_(m, config)
@@ -55,8 +55,8 @@ from torchao.prototype.mx_formats import MXInferenceLinearConfig, MXGemmKernelChoice
 m = torch.nn.Sequential(torch.nn.Linear(32, 32)).cuda()
 gemm_kernel_choice = MXGemmKernelChoice.CUBLAS
 config = MXInferenceLinearConfig(
-    elem_dtype=torch.float8_e4m3fn,
-    block_size=32,
+    elem_dtype=torch.float8_e4m3fn,
+    block_size=32,
     gemm_kernel_choice=gemm_kernel_choice,
 )
 quantize_(m, config=config)
@@ -71,10 +71,10 @@ only `torch.float32` and `torch.bfloat16` are supported as high precision format
 ```python
 from torchao.prototype.mx_formats.mx_tensor import MXTensor
 # Note: MX int8 is not implemented yet
-from torchao.prototype.mx_formats.constants import DTYPE_FP6_E2M3, DTYPE_FP6_E3M2, DTYPE_FP4
+from torchao.prototype.mx_formats.constants import DTYPE_FP6_E2M3, DTYPE_FP6_E3M2
 x = torch.randn(32, 32, device='cuda')

-# elem_dtype can be torch.float8_e4m3fn, torch.float8_e5m2, DTYPE_FP6_E2M3, DTYPE_FP6_E3M2, DTYPE_FP4
+# elem_dtype can be torch.float8_e4m3fn, torch.float8_e5m2, DTYPE_FP6_E2M3, DTYPE_FP6_E3M2, torch.float4_e2m1fn_x2
 elem_dtype = torch.float8_e4m3fn

 # high precision to MX, block size defaults to 32
@@ -88,7 +88,7 @@ x_hp = x_mx.to_dtype(torch.float)

 ## mxfp8 gemm

-On NVIDIA B200 machines, we use the cuBLAS mxfp8 gemm exposed via the `torch._scaled_mm` op.
+On NVIDIA B200 machines, we use the cuBLAS mxfp8 gemm exposed via the `torch._scaled_mm` op.
 We observe a speedup of **2x to 3x** vs the bf16 baseline on common shapes. To reproduce this
 on supported hardware, you can run the following command:
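The updated README comment above lists `torch.float4_e2m1fn_x2` as a valid `elem_dtype`. A minimal sketch of that low-level cast, assuming PyTorch 2.8+ and mirroring the `MXTensor.to_mx` calls in the tests earlier in this diff:

```python
import torch

from torchao.prototype.mx_formats.mx_tensor import MXTensor

x = torch.randn(32, 32, device="cuda")
# cast to MX with the native fp4 element dtype (PyTorch 2.8+)
x_mx = MXTensor.to_mx(x, elem_dtype=torch.float4_e2m1fn_x2, block_size=32)
# dequantize back to high precision, as in the README example
x_hp = x_mx.to_dtype(torch.float)
```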

torchao/prototype/mx_formats/benchmarks/bench_qdq.py

Lines changed: 2 additions & 2 deletions

@@ -17,7 +17,6 @@

 from torchao.prototype.mx_formats import config
 from torchao.prototype.mx_formats.constants import (  # noqa: E501
-    DTYPE_FP4,
     SUPPORTED_ELEM_DTYPES,
 )
 from torchao.prototype.mx_formats.mx_tensor import MXTensor
@@ -44,7 +43,8 @@ def run(profile_folder: Optional[str] = None):
         )

         if (
-            elem_dtype != DTYPE_FP4 and use_fp4_custom_triton_dequant_kernel  # noqa: E501
+            elem_dtype != torch.float4_e2m1fn_x2
+            and use_fp4_custom_triton_dequant_kernel  # noqa: E501
         ):
             # custom_triton_kernels only works for fp4
             continue

torchao/prototype/mx_formats/config.py

Lines changed: 4 additions & 4 deletions

@@ -12,7 +12,6 @@

 from torchao.core.config import AOBaseConfig
 from torchao.prototype.mx_formats.constants import (
-    DTYPE_FP4,
     DTYPE_FP6_E2M3,
     DTYPE_FP6_E3M2,
     DTYPE_TO_SHORT_STR,
@@ -53,7 +52,7 @@ def _validate_gemm_kernel_choice(gemm_kernel_choice, block_size, elem_dtype):
         assert block_size == 32, (
             f"block_size must be 32 to use the CUTLASS MX gemm kernels, got {block_size}"
         )
-        valid_dtypes = [torch.float8_e4m3fn, DTYPE_FP4]
+        valid_dtypes = [torch.float8_e4m3fn, torch.float4_e2m1fn_x2]
         assert elem_dtype in valid_dtypes, (
             f"elem_dtype must be one of {valid_dtypes} to use the CUTLASS MX gemm kernels, got {elem_dtype}"
         )
@@ -126,10 +125,11 @@ def from_recipe_name(
         elif recipe_name is MXLinearRecipeName.MXFP8_CUBLAS:
             return MXLinearConfig(gemm_kernel_choice=MXGemmKernelChoice.CUBLAS)
         elif recipe_name is MXLinearRecipeName.MXFP4_EMULATED:
-            return MXLinearConfig(elem_dtype=DTYPE_FP4)
+            return MXLinearConfig(elem_dtype=torch.float4_e2m1fn_x2)
         elif recipe_name is MXLinearRecipeName.MXFP4_CUTLASS:
             return MXLinearConfig(
-                elem_dtype=DTYPE_FP4, gemm_kernel_choice=MXGemmKernelChoice.CUTLASS
+                elem_dtype=torch.float4_e2m1fn_x2,
+                gemm_kernel_choice=MXGemmKernelChoice.CUTLASS,
             )
         else:
             raise AssertionError(f"unknown recipe_name {recipe_name}")
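With this change, the MXFP4 recipe and an explicit config should be equivalent. A small sketch, assuming the config defaults (e.g. `block_size=32`) are unchanged:

```python
import torch

from torchao.prototype.mx_formats.config import (
    MXGemmKernelChoice,
    MXLinearConfig,
    MXLinearRecipeName,
)

# via the recipe helper updated in this hunk ...
config_a = MXLinearConfig.from_recipe_name(MXLinearRecipeName.MXFP4_CUTLASS)

# ... or spelled out explicitly with the new native fp4 dtype
config_b = MXLinearConfig(
    elem_dtype=torch.float4_e2m1fn_x2,
    gemm_kernel_choice=MXGemmKernelChoice.CUTLASS,
)
```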

torchao/prototype/mx_formats/constants.py

Lines changed: 9 additions & 3 deletions

@@ -5,10 +5,11 @@
 # LICENSE file in the root directory of this source tree.
 import torch

+from torchao.utils import TORCH_VERSION_AT_LEAST_2_8
+
 # This is conceptually an enum of non-core dtypes
 # TODO(future PR): change to a cleaner way to represent this without
 # regressing torch.compile and while keeping things readable.
-DTYPE_FP4 = "fp4_e2m1"
 DTYPE_FP6_E3M2 = "fp6_e3m2"
 DTYPE_FP6_E2M3 = "fp6_e2m3"

@@ -19,16 +20,21 @@
     torch.float8_e5m2,
     DTYPE_FP6_E2M3,
     DTYPE_FP6_E3M2,
-    DTYPE_FP4,
 ]
+SUPPORTED_ELEM_DTYPES = (
+    SUPPORTED_ELEM_DTYPES + [torch.float4_e2m1fn_x2]
+    if TORCH_VERSION_AT_LEAST_2_8
+    else SUPPORTED_ELEM_DTYPES
+)

 DTYPE_TO_SHORT_STR = {
     torch.float8_e4m3fn: "f8e4m3",
     torch.float8_e5m2: "f8e5m2",
     DTYPE_FP6_E2M3: "f6e2m3",
     DTYPE_FP6_E3M2: "f6e3m2",
-    DTYPE_FP4: "f4e2m1",
 }
+if TORCH_VERSION_AT_LEAST_2_8:
+    DTYPE_TO_SHORT_STR[torch.float4_e2m1fn_x2] = "f4e2m1"

 F8E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0
 F8E5M2_MAX = torch.finfo(torch.float8_e5m2).max  # 57344.0
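Since `torch.float4_e2m1fn_x2` only exists on PyTorch 2.8+, the constants above are gated on the torch version. A minimal sketch of the same guard pattern the test changes in this commit follow:

```python
import torch

from torchao.utils import TORCH_VERSION_AT_LEAST_2_8

# only reference the native fp4 dtype on PyTorch 2.8+, where it is defined
elem_dtypes = [torch.float8_e4m3fn]
if TORCH_VERSION_AT_LEAST_2_8:
    elem_dtypes.append(torch.float4_e2m1fn_x2)
```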

torchao/prototype/mx_formats/fp_format_spec.py

Lines changed: 2 additions & 3 deletions

@@ -16,7 +16,6 @@
 import torch

 from torchao.prototype.mx_formats.constants import (
-    DTYPE_FP4,
     DTYPE_FP6_E2M3,
     DTYPE_FP6_E3M2,
 )
@@ -494,7 +493,7 @@ def run(dtype):
     headers = ["orig_val", "formula", "s_enc", "e_enc", "m_enc", "note"]
     results = []

-    if dtype == DTYPE_FP4:
+    if dtype == torch.float4_e2m1fn_x2:
         results = float4_e2m1_interesting_values
     elif dtype == DTYPE_FP6_E3M2:
         results = float6_e3m2_interesting_values
@@ -539,6 +538,6 @@ def run(dtype):
         torch.float8_e5m2,
         DTYPE_FP6_E3M2,
         DTYPE_FP6_E2M3,
-        DTYPE_FP4,
+        torch.float4_e2m1fn_x2,
     ):
         run(dtype)
