Commit 3c9638d

Add kernel (#2439)

stack-info: PR: #2439, branch: drisspg/stack/81

1 parent c044ddb · commit 3c9638d
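
The diff below parametrizes the NVFP4 inference tests in test/prototype/mx_formats/test_mx_linear.py over a new use_triton_kernel flag and several M/K/N shapes, and adds test_triton_nvfp4_quantize_equivalence to test/prototype/mx_formats/test_mx_tensor.py, verifying that the Triton quantization kernel reproduces the PyTorch reference: identical packed uint4 data, matching e4m3 block scales, and at least 40 dB SQNR after dequantization.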

File tree: 6 files changed, +474 −35 lines

test/prototype/mx_formats/test_mx_linear.py

Lines changed: 50 additions & 5 deletions
@@ -37,6 +37,7 @@
 from torchao.utils import (
     TORCH_VERSION_AT_LEAST_2_8,
     is_sm_at_least_89,
+    is_sm_at_least_90,
     is_sm_at_least_100,
 )

@@ -459,10 +460,29 @@ def test_inference_subclass(elem_dtype, bias: bool, compile: bool):
     "mm_config", [NVFP4MMConfig.DYNAMIC, NVFP4MMConfig.WEIGHT_ONLY]
 )
 @pytest.mark.parametrize("inpt_dtype", [torch.bfloat16, torch.float32])
+@pytest.mark.parametrize("use_triton_kernel", [True, False])
+@pytest.mark.parametrize(
+    "shapes",
+    [
+        (128, 64, 256),
+        (256, 128, 512),
+        (145, 64, 256),
+        (128, 96, 256),
+        (128, 160, 256),
+        (64, 64, 256),
+        (200, 192, 256),
+    ],
+    ids=lambda s: f"{s[0]}x{s[1]}x{s[2]}",
+)
 @torch.no_grad()
 @skip_if_rocm("ROCm float4 gemm require gfx950")
 def test_inference_subclass_nvfp4(
-    bias: bool, compile: bool, mm_config: NVFP4MMConfig, inpt_dtype: torch.dtype
+    bias: bool,
+    compile: bool,
+    mm_config: NVFP4MMConfig,
+    inpt_dtype: torch.dtype,
+    use_triton_kernel: bool,
+    shapes: tuple,
 ):
     """
     Test NVFP4 recipe with scale_dtype=float8_e4m3fn and block_size=16

@@ -477,16 +497,20 @@ def test_inference_subclass_nvfp4(

     if mm_config == NVFP4MMConfig.WEIGHT_ONLY and compile:
         pytest.skip("TODO: NVFP4MMConfig.WEIGHT_ONLY currently errors w/ compile")
-    m = nn.Linear(64, 256, bias=bias, dtype=inpt_dtype, device="cuda")
+    batch_size, in_features, out_features = shapes
+
+    m = nn.Linear(in_features, out_features, bias=bias, dtype=inpt_dtype, device="cuda")
     m_mx = copy.deepcopy(m)

-    config = NVFP4InferenceConfig(mm_config=mm_config)
+    config = NVFP4InferenceConfig(
+        mm_config=mm_config, use_triton_kernel=use_triton_kernel
+    )
     quantize_(m_mx, config=config)

     if compile:
         m_mx = torch.compile(m_mx, fullgraph=True, backend="aot_eager")

-    x = torch.randn(128, 64, device="cuda", dtype=inpt_dtype)
+    x = torch.randn(batch_size, in_features, device="cuda", dtype=inpt_dtype)
     y_ref = m(x)
     y_mx = m_mx(x)
     sqnr = compute_error(y_ref, y_mx)

@@ -513,14 +537,33 @@ def test_inference_subclass_nvfp4(
 @pytest.mark.parametrize("compile", [False])
 @pytest.mark.parametrize("bias", [True, False])
 @pytest.mark.parametrize("inpt_dtype", [torch.bfloat16, torch.float32])
+@pytest.mark.parametrize("use_triton_kernel", [True, False])
+@pytest.mark.parametrize(
+    "shapes",
+    [
+        (128, 64, 256),
+        (256, 128, 512),
+        (157, 64, 256),
+        (128, 96, 256),
+        (128, 160, 256),
+        (64, 64, 256),
+        (200, 192, 256),
+    ],
+    ids=lambda s: f"{s[0]}x{s[1]}x{s[2]}",
+)
 @torch.no_grad()
 @skip_if_rocm("ROCm float4 gemm require gfx950")
+@pytest.mark.skipif(
+    not is_sm_at_least_90(), reason="CUDA capability >= 9.0 required for fp8e4nv"
+)
 def test_nvfp4_matmul_with_amax(
     use_gelu: bool,
     mm_config: NVFP4MMConfig,
     compile: bool,
     bias: bool,
     inpt_dtype: torch.dtype,
+    use_triton_kernel: bool,
+    shapes: tuple,
 ):
     from torchao.prototype.mx_formats.nvfp4_tensor import (
         NVFP4Tensor,

@@ -537,7 +580,7 @@ def test_nvfp4_matmul_with_amax(
     if mm_config == NVFP4MMConfig.WEIGHT_ONLY and compile:
         pytest.skip("TODO: NVFP4MMConfig.WEIGHT_ONLY currently errors w/ compile")

-    m, k, n = 64, 256, 128
+    m, k, n = shapes

     # Create activation tensor
     if use_gelu:

@@ -559,12 +602,14 @@ def test_nvfp4_matmul_with_amax(
         per_tensor_scale=a_scale,
         mm_config=mm_config,
         is_swizzled_scales=True,
+        use_triton_kernel=use_triton_kernel,
     )
     B_nvfp4 = NVFP4Tensor.to_nvfp4(
         B,
         per_tensor_scale=b_scale,
         mm_config=mm_config,
         is_swizzled_scales=True,
+        use_triton_kernel=use_triton_kernel,
     )

     func = torch.compile(F.linear, fullgraph=True) if compile else F.linear
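
For orientation, here is a minimal sketch of the inference path these tests exercise, mirroring the test bodies above. The import locations for NVFP4InferenceConfig, NVFP4MMConfig, quantize_, and compute_error are assumptions about the torchao layout and may differ between versions.

import copy

import torch
import torch.nn as nn

# Assumed import locations (prototype APIs may move between releases):
from torchao.prototype.mx_formats import NVFP4InferenceConfig, NVFP4MMConfig
from torchao.quantization import quantize_
from torchao.quantization.utils import compute_error

m = nn.Linear(64, 256, bias=False, dtype=torch.bfloat16, device="cuda")
m_mx = copy.deepcopy(m)

# The new flag routes NVFP4 block-scale quantization through the Triton kernel.
config = NVFP4InferenceConfig(
    mm_config=NVFP4MMConfig.DYNAMIC, use_triton_kernel=True
)
quantize_(m_mx, config=config)

x = torch.randn(128, 64, device="cuda", dtype=torch.bfloat16)
sqnr = compute_error(m(x), m_mx(x))  # SQNR in dB against the bf16 reference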

test/prototype/mx_formats/test_mx_tensor.py

Lines changed: 67 additions & 0 deletions
@@ -27,6 +27,7 @@
 from torchao.utils import (
     TORCH_VERSION_AT_LEAST_2_8,
     is_sm_at_least_89,
+    is_sm_at_least_100,
 )

 torch.manual_seed(2)

@@ -955,3 +956,69 @@ def test_nvfp4_swizzled_scales_get_scales_method():
     expected_shape = (M, K // 16)
     assert regular_scales.shape == expected_shape
     assert swizzled_scales.shape == expected_shape
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.parametrize(
+    "M", [128, 256, 512, 1024, 100, 200, 384], ids=lambda m: f"M{m}"
+)
+@pytest.mark.parametrize("N", [64, 128, 256, 512, 32, 96, 160], ids=lambda n: f"N{n}")
+@pytest.mark.parametrize(
+    "use_per_tensor_scale", [False, True], ids=["block_scale", "tensor_scale"]
+)
+@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16], ids=["fp32", "bf16"])
+@pytest.mark.skipif(
+    not is_sm_at_least_100(), reason="requires sm100+ for raw intrinsics"
+)
+@torch.no_grad()
+def test_triton_nvfp4_quantize_equivalence(M, N, use_per_tensor_scale, dtype):
+    """Test that Triton and PyTorch NVFP4 quantization produce equivalent results."""
+    from torchao.prototype.mx_formats.nvfp4_tensor import (
+        NVFP4Tensor,
+        per_tensor_amax_to_scale,
+        unpack_uint4,
+    )
+
+    torch.manual_seed(42)
+    x = torch.randn(M, N, dtype=dtype, device="cuda")
+
+    per_tensor_scale = None
+    if use_per_tensor_scale:
+        per_tensor_scale = per_tensor_amax_to_scale(torch.amax(torch.abs(x)))
+
+    nvfp4_pt = NVFP4Tensor.to_nvfp4(
+        x.clone(),
+        per_tensor_scale=per_tensor_scale,
+        is_swizzled_scales=True,
+        use_triton_kernel=False,
+    )
+
+    nvfp4_triton = NVFP4Tensor.to_nvfp4(
+        x.clone(),
+        per_tensor_scale=per_tensor_scale,
+        is_swizzled_scales=True,
+        use_triton_kernel=True,
+    )
+
+    torch.testing.assert_close(
+        nvfp4_pt._scale_e4m3.flatten(), nvfp4_triton._scale_e4m3.flatten()
+    )
+    pt_unpacked = unpack_uint4(nvfp4_pt._data)
+    triton_unpacked = unpack_uint4(nvfp4_triton._data)
+    torch.testing.assert_close(
+        pt_unpacked,
+        triton_unpacked,
+        atol=0,
+        rtol=0,
+    )
+
+    x_pt_dequant = nvfp4_pt.to_dtype(dtype)
+    x_triton_dequant = nvfp4_triton.to_dtype(dtype)
+
+    sqnr = compute_error(x_pt_dequant, x_triton_dequant)
+    SQNR_THRESHOLD = 40.0
+
+    assert sqnr >= SQNR_THRESHOLD, (
+        f"SQNR {sqnr:.2f} < {SQNR_THRESHOLD} for M={M}, N={N}, "
+        f"use_per_tensor_scale={use_per_tensor_scale}, dtype={dtype}"
+    )
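
The final assertion is stated in terms of compute_error, torchao's SQNR helper. As a reading aid for the 40 dB threshold, a minimal sketch of the standard signal-to-quantization-noise ratio it is assumed to compute:

import torch

def sqnr_db(ref: torch.Tensor, actual: torch.Tensor) -> torch.Tensor:
    # 20 * log10(||ref|| / ||ref - actual||); 40 dB means the error norm
    # is about 1% of the signal norm.
    return 20 * torch.log10(
        torch.linalg.norm(ref) / torch.linalg.norm(ref - actual)
    )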
