
Commit 4823b37

Add CUDA kernel for MXFP8 dim1 casting
Co-authored-by: Less Wright <lessw@etrillium.com>
stack-info: PR: #2513, branch: danielvegamyhre/stack/3
1 parent c1e84cc commit 4823b37

File tree

9 files changed: +2244 −13 lines changed

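Based on the benchmark and test changes in this commit, the new kernel is exposed to Python through mxfp8_cuda.quantize. A minimal usage sketch, assuming the mxfp8_cuda extension has been built and an SM 10.0+ GPU is present; the meaning of the first and third return values is inferred from the rowwise/colwise flags and is not confirmed by this diff:

import torch
import mxfp8_cuda  # extension added by this commit; must be compiled separately

x = torch.randn(2048, 2048, dtype=torch.bfloat16, device="cuda")

# Column-wise (dim1) MXFP8 cast with floor scaling and 1x32 scale blocks, as
# used in the benchmark and tests below. The call returns four tensors,
# presumably (rowwise_data, colwise_data, rowwise_scales, colwise_scales);
# only the colwise outputs are needed here since rowwise=False.
_, y_d1, _, s_d1 = mxfp8_cuda.quantize(
    x,
    rowwise=False,
    colwise=True,
    scaling_mode="floor",  # or "rceil"
    scale_dim_x=1,
    scale_dim_y=32,
)

assert y_d1.dtype == torch.float8_e4m3fn   # quantized data
assert s_d1.dtype == torch.float8_e8m0fnu  # shared per-block scales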

benchmarks/mx_formats/cast_bench.py

Lines changed: 58 additions & 8 deletions
@@ -4,18 +4,26 @@
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
 
-from typing import Callable, Tuple
+from functools import partial
+from typing import Tuple
 
 import fire
 import torch
 import triton
-from torch._inductor.utils import do_bench_using_profiling
+from triton.testing import do_bench
 
 from torchao.prototype.mx_formats.kernels import (
     triton_to_mxfp8_dim1,
 )
 from torchao.prototype.mx_formats.mx_tensor import to_mx
 
+try:
+    import mxfp8_cuda
+except ImportError:
+    print(
+        "Warning: mxfp8_cuda extension not found or ready. Benchmarks using this will not be able to run."
+    )
+
 torch.manual_seed(0)
 
 bytes_per_el_bf16 = 2
@@ -64,11 +72,8 @@ def to_mx_dim1_reference(x_hp, block_size):
     return data_d1.t(), scale_d1
 
 
-def benchmark_cuda_function_in_microseconds(func: Callable, *args, **kwargs) -> float:
-    """Thin wrapper around do_bench_using_profiling"""
-    no_args = lambda: func(*args, **kwargs)
-    time = do_bench_using_profiling(no_args)
-    return time * 1e3
+def benchmark_cuda_function_in_microseconds(f, *args):
+    return do_bench(lambda: f(*args), return_mode="median") * 1e3
 
 
 def run(
@@ -82,7 +87,16 @@ def run(
     print(f"torch version: {torch.__version__}")
     print(f"triton version: {triton.__version__}")
     print(f"mode: {mode}")
-    assert mode in ("dim0", "dim1", "dim0_dim1", "dim0_mx", "dim1_mx", "dim1_mx_triton")
+    assert mode in (
+        "dim0",
+        "dim1",
+        "dim0_dim1",
+        "dim0_mx",
+        "dim1_mx",
+        "dim1_mx_triton",
+        "dim1_mx_cuda_floor",
+        "dim1_mx_cuda_rceil",
+    )
 
     x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda") * 1000
 
@@ -194,6 +208,42 @@ def run(
         bytes_w = (y_d1.numel() + s_d1.numel()) * bytes_per_el_fp8
         bps = (bytes_r + bytes_w) / (time_us / 1e6)
 
+    elif mode == "dim1_mx_cuda_floor":
+        bench_fn = partial(
+            mxfp8_cuda.quantize, rowwise=False, colwise=True, scaling_mode="floor"
+        )
+        _, y_d1, _, s_d1 = bench_fn(x)
+
+        for _ in range(2):
+            __ = bench_fn(x)
+
+        time_us = benchmark_cuda_function_in_microseconds(bench_fn, x)
+
+        assert y_d1.dtype == torch.float8_e4m3fn
+        assert s_d1.dtype == torch.float8_e8m0fnu
+
+        bytes_r = x.numel() * bytes_per_el_bf16
+        bytes_w = (y_d1.numel() + s_d1.numel()) * bytes_per_el_fp8
+        bps = (bytes_r + bytes_w) / (time_us / 1e6)
+
+    elif mode == "dim1_mx_cuda_rceil":
+        bench_fn = partial(
+            mxfp8_cuda.quantize, rowwise=False, colwise=True, scaling_mode="rceil"
+        )
+        _, y_d1, _, s_d1 = bench_fn(x)
+
+        for _ in range(2):
+            __ = bench_fn(x)
+
+        time_us = benchmark_cuda_function_in_microseconds(bench_fn, x)
+
+        assert y_d1.dtype == torch.float8_e4m3fn
+        assert s_d1.dtype == torch.float8_e8m0fnu
+
+        bytes_r = x.numel() * bytes_per_el_bf16
+        bytes_w = (y_d1.numel() + s_d1.numel()) * bytes_per_el_fp8
+        bps = (bytes_r + bytes_w) / (time_us / 1e6)
+
     else:
         raise AssertionError(f"unknown mode {mode}")
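The two new CUDA benchmark modes report achieved memory bandwidth the same way as the existing dim1_mx_triton path: bytes read are the bf16 input, bytes written are the fp8 output plus the one-byte e8m0 scales, divided by the measured kernel time. A worked example of that arithmetic (the shape and the timing value are illustrative, not measured results):

M, K = 16384, 16384
block_size = 32
bytes_per_el_bf16 = 2
bytes_per_el_fp8 = 1

bytes_r = M * K * bytes_per_el_bf16                           # bf16 input read
bytes_w = (M * K + (M * K) // block_size) * bytes_per_el_fp8  # fp8 data + e8m0 scales written
time_us = 150.0                                               # hypothetical median kernel time

bps = (bytes_r + bytes_w) / (time_us / 1e6)
print(f"{bps / 1e9:.0f} GB/s")  # about 5425 GB/s for these made-up numbers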

test/prototype/mx_formats/test_kernels.py

Lines changed: 105 additions & 1 deletion
@@ -42,14 +42,20 @@
     triton_to_mxfp8_dim1_reference,
     unpack_uint4,
 )
-from torchao.prototype.mx_formats.mx_tensor import MXTensor
+from torchao.prototype.mx_formats.mx_tensor import MXTensor, ScaleCalculationMode
 from torchao.prototype.mx_formats.utils import to_blocked
 from torchao.utils import (
     TORCH_VERSION_AT_LEAST_2_8,
     is_sm_at_least_89,
     is_sm_at_least_100,
 )
 
+try:
+    import mxfp8_cuda
+except ImportError:
+    print("Warning: MXFP8 CUDA extension not available, some tests will be skipped")
+    pass
+
 torch.manual_seed(0)
 
 if not TORCH_VERSION_AT_LEAST_2_8:
@@ -488,3 +494,101 @@ def test_rearrange(shape):
     eager = to_blocked(scales, False)
     triton = to_blocked(scales, True)
     torch.testing.assert_close(eager, triton, atol=0, rtol=0)
+
+
+@pytest.mark.skipif(
+    not is_sm_at_least_100(),
+    reason="MXFP8 requires CUDA capability 10.0 or greater",
+)
+@pytest.mark.skipif(
+    "mxfp8_cuda" not in globals(),
+    reason="mxfp8_cuda extension not available",
+)
+@pytest.mark.parametrize("M", (32, 64, 2048))
+@pytest.mark.parametrize("K", (32, 64, 2048))
+@pytest.mark.parametrize("input_dtype", (torch.float32, torch.bfloat16))
+@pytest.mark.parametrize(
+    "scaling_mode", (ScaleCalculationMode.FLOOR, ScaleCalculationMode.RCEIL)
+)
+def test_cuda_mx_dim1_numerics(M, K, input_dtype, scaling_mode):
+    scaling_mode_str = (
+        "floor" if scaling_mode == ScaleCalculationMode.FLOOR else "rceil"
+    )
+    block_size = 32
+
+    # Use distinct incrementing values from 0 to M*K-1 to make debugging easier.
+    x = (
+        torch.arange(0, M * K, dtype=input_dtype, device="cuda")
+        .reshape(M, K)
+        .contiguous()
+    )
+
+    y_d1_ref, s_d1_ref = triton_to_mxfp8_dim1_reference(
+        x, block_size=block_size, scaling_mode=scaling_mode
+    )
+    _, y_d1, _, s_d1 = mxfp8_cuda.quantize(
+        x,
+        rowwise=False,
+        colwise=True,
+        scaling_mode=scaling_mode_str,
+        scale_dim_x=1,
+        scale_dim_y=block_size,
+    )
+
+    # check scales
+    torch.testing.assert_close(s_d1, s_d1_ref, rtol=0, atol=0)
+
+    # check quantized values
+    torch.testing.assert_close(y_d1, y_d1_ref, rtol=0, atol=0)
+
+
+@pytest.mark.skipif(
+    not is_sm_at_least_100(),
+    reason="MXFP8 requires CUDA capability 10.0 or greater",
+)
+@pytest.mark.skipif(
+    "mxfp8_cuda" not in globals(),
+    reason="mxfp8_cuda extension not available",
+)
+def test_cuda_mx_dim0_not_supported():
+    M, K = 64, 64
+    block_size = 32
+    x = (
+        torch.arange(0, M * K, dtype=torch.bfloat16, device="cuda")
+        .reshape(M, K)
+        .contiguous()
+    )
+    with pytest.raises(RuntimeError):
+        _, y_d1, _, s_d1 = mxfp8_cuda.quantize(
+            x,
+            rowwise=True,
+            colwise=False,
+            scale_dim_x=block_size,
+            scale_dim_y=1,
+        )
+
+
+@pytest.mark.skipif(
+    not is_sm_at_least_100(),
+    reason="MXFP8 requires CUDA capability 10.0 or greater",
+)
+@pytest.mark.skipif(
+    "mxfp8_cuda" not in globals(),
+    reason="mxfp8_cuda extension not available",
+)
+def test_cuda_mx_dim1_invalid_block_size():
+    M, K = 64, 64
+    x = (
+        torch.arange(0, M * K, dtype=torch.bfloat16, device="cuda")
+        .reshape(M, K)
+        .contiguous()
+    )
+    invalid_block_size = 4
+    with pytest.raises(RuntimeError):
+        _, y_d1, _, s_d1 = mxfp8_cuda.quantize(
            x,
            rowwise=False,
            colwise=True,
            scale_dim_x=1,
            scale_dim_y=invalid_block_size,
        )
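The parametrized numerics test checks both scale calculation modes against the Triton reference. As background, a rough sketch of how the two modes are commonly defined for MXFP8 with an e4m3 target (illustrative only; torchao's to_mx, reached through triton_to_mxfp8_dim1_reference, is the ground truth the test compares against):

import math

F8E4M3_MAX = 448.0   # largest finite float8_e4m3fn value
E4M3_MAX_POW2 = 8    # floor(log2(448))

def floor_scale_exponent(block_amax: float) -> int:
    # "floor": truncate log2 of the block amax and offset by the target's max exponent
    return math.floor(math.log2(block_amax)) - E4M3_MAX_POW2

def rceil_scale_exponent(block_amax: float) -> int:
    # "rceil": round block_amax / F8E4M3_MAX up to the next power of two
    return math.ceil(math.log2(block_amax / F8E4M3_MAX))

# For a block whose max magnitude is 1000, the two modes can disagree by one step:
print(floor_scale_exponent(1000.0), rceil_scale_exponent(1000.0))  # 1 2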

torchao/prototype/mx_formats/kernels.py

Lines changed: 10 additions & 4 deletions
@@ -1375,17 +1375,23 @@ def custom_triton_to_mxfp8_dim1_sharding(x, inner_block_size=32):
     return acceptable_shardings
 
 def triton_to_mxfp8_dim1_reference(
-    x_hp: torch.Tensor, block_size
+    x_hp: torch.Tensor,
+    block_size,
+    scaling_mode=None,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """
     A reference version of `to_mxfp8_dim1`.
     """
-    from torchao.prototype.mx_formats.mx_tensor import to_mx
+    from torchao.prototype.mx_formats.mx_tensor import ScaleCalculationMode, to_mx
+
+    scale_mode = (
+        ScaleCalculationMode.FLOOR if scaling_mode is None else scaling_mode
+    )
 
     # cast across dim1
     x_hp_d1 = x_hp.t().contiguous()
     scale_e8m0_dim1, x_hp_d1_normalized = to_mx(
-        x_hp_d1, torch.float8_e4m3fn, block_size
+        x_hp_d1, torch.float8_e4m3fn, block_size, scaling_mode=scale_mode
     )
     scale_e8m0_dim1 = scale_e8m0_dim1.view(torch.float8_e8m0fnu)
     return (
@@ -1718,7 +1724,7 @@ def triton_to_mxfp8_dim1(
     raise AssertionError("needs torch version 2.8+ and triton")
 
 def triton_to_mxfp8_dim1_reference(
-    x_hp: torch.Tensor, block_size
+    x_hp: torch.Tensor, block_size, scaling_mode
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     raise AssertionError("needs torch version 2.8+ and triton")
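Since scaling_mode defaults to None (which maps to ScaleCalculationMode.FLOOR), existing callers of the reference cast keep their behavior, while the new tests can request RCEIL explicitly. A minimal call sketch mirroring the test, assuming torch 2.8+ with Triton available:

import torch

from torchao.prototype.mx_formats.kernels import triton_to_mxfp8_dim1_reference
from torchao.prototype.mx_formats.mx_tensor import ScaleCalculationMode

x = torch.randn(64, 64, dtype=torch.bfloat16, device="cuda")

# Default: omitting scaling_mode falls back to FLOOR, matching prior behavior.
y_floor, s_floor = triton_to_mxfp8_dim1_reference(x, block_size=32)

# Explicit RCEIL scaling, as exercised by test_cuda_mx_dim1_numerics.
y_rceil, s_rceil = triton_to_mxfp8_dim1_reference(
    x, block_size=32, scaling_mode=ScaleCalculationMode.RCEIL
)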
