Commit 0065bcd

Add CUDA kernel for MXFP8 dim1 casting
Co-authored-by: Less Wright <lessw@etrillium.com>
stack-info: PR: #2513, branch: danielvegamyhre/stack/3
1 parent ddd4021 commit 0065bcd

9 files changed: +2274 / −19 lines
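
The new mxfp8_cuda extension exposes a quantize entry point, which the benchmark and tests below call for the dim1 (column-wise) cast. A minimal usage sketch, pieced together from the call sites in this diff; reading the first and third return slots as the rowwise outputs is an inference from the `_, y_d1, _, s_d1` unpacking, not something stated on this page:

import torch

try:
    import mxfp8_cuda  # CUDA extension added by this commit; must be built first
except ImportError:
    mxfp8_cuda = None

if mxfp8_cuda is not None and torch.cuda.is_available():
    x = torch.randn(4096, 4096, dtype=torch.bfloat16, device="cuda")
    # Cast across dim1 (column-wise) with floor-based scale calculation.
    _, y_d1, _, s_d1 = mxfp8_cuda.quantize(
        x, rowwise=False, colwise=True, scaling_mode="floor"
    )
    assert y_d1.dtype == torch.float8_e4m3fn   # quantized data
    assert s_d1.dtype == torch.float8_e8m0fnu  # e8m0 per-block scales

The "rceil" scaling mode is exercised the same way; only the scaling_mode string changes.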

benchmarks/mx_formats/cast_bench.py

Lines changed: 75 additions & 14 deletions

@@ -4,18 +4,25 @@
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
 
-from typing import Callable, Tuple
+from typing import Tuple
 
 import fire
 import torch
 import triton
-from torch._inductor.utils import do_bench_using_profiling
+from triton.testing import do_bench
 
 from torchao.prototype.mx_formats.kernels import (
     triton_to_mxfp8_dim1,
 )
 from torchao.prototype.mx_formats.mx_tensor import to_mx
 
+try:
+    import mxfp8_cuda
+except ImportError:
+    print(
+        "Warning: mxfp8_cuda extension not found or ready. Benchmarks using this will not be able to run."
+    )
+
 torch.manual_seed(0)
 
 bytes_per_el_bf16 = 2
@@ -64,29 +71,35 @@ def to_mx_dim1_reference(x_hp, block_size):
     return data_d1.t(), scale_d1
 
 
-def benchmark_cuda_function_in_microseconds(func: Callable, *args, **kwargs) -> float:
-    """Thin wrapper around do_bench_using_profiling"""
-    no_args = lambda: func(*args, **kwargs)
-    time = do_bench_using_profiling(no_args)
-    return time * 1e3
+def benchmark_cuda_function_in_microseconds(f, *args):
+    return do_bench(lambda: f(*args), return_mode="median") * 1e3
 
 
 def run(
     M: int = 16384,
     K: int = 16384,
     BLOCK_SIZE: int = 32,
-    mode: str = "dim0",
+    mode: str = "dim0_floor",
 ):
     print(f"M {M} K {K} BLOCK_SIZE {BLOCK_SIZE}")
     print(f"GPU: {torch.cuda.get_device_name(0)}")
     print(f"torch version: {torch.__version__}")
     print(f"triton version: {triton.__version__}")
     print(f"mode: {mode}")
-    assert mode in ("dim0", "dim1", "dim0_dim1", "dim0_mx", "dim1_mx", "dim1_mx_triton")
+    assert mode in (
+        "dim0_floor",
+        "dim1_floor",
+        "dim0_dim1_floor",
+        "dim0_mx_floor",
+        "dim1_mx_floor",
+        "dim1_mx_triton",
+        "dim1_mx_cuda_floor",
+        "dim1_mx_cuda_rceil",
+    )
 
     x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda") * 1000
 
-    if mode == "dim0":
+    if mode == "dim0_floor":
         scale_dim0_reference_c = torch.compile(scale_dim0_reference)
         y_d0, s_d0 = scale_dim0_reference_c(x, BLOCK_SIZE)
 
@@ -103,7 +116,7 @@ def run(
         bytes_rw = sum(t.numel() for t in [x, y_d0, s_d0]) * bytes_per_el_bf16
         bps = bytes_rw / (time_us / 1e6)
 
-    elif mode == "dim1":
+    elif mode == "dim1_floor":
         scale_dim1_reference_c = torch.compile(scale_dim1_reference)
         y_d1, s_d1 = scale_dim1_reference_c(x, BLOCK_SIZE)
 
@@ -120,7 +133,7 @@ def run(
         bytes_rw = sum(t.numel() for t in [x, y_d1, s_d1]) * bytes_per_el_bf16
         bps = bytes_rw / (time_us / 1e6)
 
-    elif mode == "dim0_dim1":
+    elif mode == "dim0_dim1_floor":
         scale_dim0_dim1_reference_c = torch.compile(scale_dim0_dim1_reference)
         y_d0, y_d1, s_d0, s_d1 = scale_dim0_dim1_reference_c(x, BLOCK_SIZE)
 
@@ -141,7 +154,7 @@ def run(
         )
         bps = bytes_rw / (time_us / 1e6)
 
-    elif mode == "dim0_mx":
+    elif mode == "dim0_mx_floor":
         to_mx_dim0_reference_c = torch.compile(to_mx_dim0_reference)
         y_d0, s_d0 = to_mx_dim0_reference_c(x, BLOCK_SIZE)
 
@@ -159,7 +172,7 @@ def run(
         bytes_w = (y_d0.numel() + s_d0.numel()) * bytes_per_el_fp8
         bps = (bytes_r + bytes_w) / (time_us / 1e6)
 
-    elif mode == "dim1_mx":
+    elif mode == "dim1_mx_floor":
        to_mx_dim1_reference_c = torch.compile(to_mx_dim1_reference)
        y_d1, s_d1 = to_mx_dim1_reference_c(x, BLOCK_SIZE)
 
@@ -194,6 +207,54 @@ def run(
         bytes_w = (y_d1.numel() + s_d1.numel()) * bytes_per_el_fp8
         bps = (bytes_r + bytes_w) / (time_us / 1e6)
 
+    elif mode == "dim1_mx_cuda_floor":
+        _, y_d1, _, s_d1 = mxfp8_cuda.quantize(
+            x, rowwise=False, colwise=True, scaling_mode="floor"
+        )
+
+        for _ in range(2):
+            __ = mxfp8_cuda.quantize(
+                x, rowwise=False, colwise=True, scaling_mode="floor"
+            )
+
+        time_us = benchmark_cuda_function_in_microseconds(
+            lambda x: mxfp8_cuda.quantize(
+                x, rowwise=False, colwise=True, scaling_mode="floor"
+            ),
+            x,
+        )
+
+        assert y_d1.dtype == torch.float8_e4m3fn
+        assert s_d1.dtype == torch.float8_e8m0fnu
+
+        bytes_r = x.numel() * bytes_per_el_bf16
+        bytes_w = (y_d1.numel() + s_d1.numel()) * bytes_per_el_fp8
+        bps = (bytes_r + bytes_w) / (time_us / 1e6)
+
+    elif mode == "dim1_mx_cuda_rceil":
+        _, y_d1, _, s_d1 = mxfp8_cuda.quantize(
+            x, rowwise=False, colwise=True, scaling_mode="rceil"
+        )
+
+        for _ in range(2):
+            __ = mxfp8_cuda.quantize(
+                x, rowwise=False, colwise=True, scaling_mode="rceil"
+            )
+
+        time_us = benchmark_cuda_function_in_microseconds(
+            lambda x: mxfp8_cuda.quantize(
+                x, rowwise=False, colwise=True, scaling_mode="rceil"
+            ),
+            x,
+        )
+
+        assert y_d1.dtype == torch.float8_e4m3fn
+        assert s_d1.dtype == torch.float8_e8m0fnu
+
+        bytes_r = x.numel() * bytes_per_el_bf16
+        bytes_w = (y_d1.numel() + s_d1.numel()) * bytes_per_el_fp8
+        bps = (bytes_r + bytes_w) / (time_us / 1e6)
+
     else:
         raise AssertionError(f"unknown mode {mode}")
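
With the new dim1_mx_cuda_floor and dim1_mx_cuda_rceil modes wired in, the CUDA kernel can be benchmarked alongside the compiled and triton paths. A sketch of driving the benchmark directly from Python; the import path is an assumption about how the benchmarks directory resolves (the script is normally invoked through its fire CLI):

# Assumes the repo root is on sys.path, the mxfp8_cuda extension is built,
# and a supported CUDA GPU is available.
from benchmarks.mx_formats.cast_bench import run

run(M=16384, K=16384, BLOCK_SIZE=32, mode="dim1_mx_cuda_floor")
run(M=16384, K=16384, BLOCK_SIZE=32, mode="dim1_mx_cuda_rceil")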

test/prototype/mx_formats/test_kernels.py

Lines changed: 120 additions & 1 deletion

@@ -42,20 +42,37 @@
     triton_to_mxfp8_dim1_reference,
     unpack_uint4,
 )
-from torchao.prototype.mx_formats.mx_tensor import MXTensor
+from torchao.prototype.mx_formats.mx_tensor import MXTensor, ScaleCalculationMode, to_mx
 from torchao.prototype.mx_formats.utils import to_blocked
 from torchao.utils import (
     TORCH_VERSION_AT_LEAST_2_8,
     is_sm_at_least_89,
     is_sm_at_least_100,
 )
 
+try:
+    import mxfp8_cuda
+except ImportError:
+    print("Warning: MXFP8 CUDA extension not available, some tests will be skipped")
+    pass
+
 torch.manual_seed(0)
 
 if not TORCH_VERSION_AT_LEAST_2_8:
     pytest.skip("Unsupported PyTorch version", allow_module_level=True)
 
 
+# TODO: shared utils file for benchmarking and testing
+def to_mx_dim1_reference(x_hp, block_size, scaling_mode):
+    x_hp = x_hp.t().contiguous()
+    scale_d1, data_d1 = to_mx(
+        x_hp, torch.float8_e4m3fn, block_size, scaling_mode=scaling_mode
+    )
+    return data_d1.t(), scale_d1.squeeze(
+        -1
+    )  # torchao impl returns an extra empty dim that the triton / cuda kernels do not
+
+
 @pytest.mark.skip(
     reason="TODO debug CI failure, low pri since this is not used in the MX code"  # noqa: E501
 )
@@ -488,3 +505,105 @@ def test_rearrange(shape):
     eager = to_blocked(scales, False)
     triton = to_blocked(scales, True)
     torch.testing.assert_close(eager, triton, atol=0, rtol=0)
+
+
+@pytest.mark.skipif(
+    not is_sm_at_least_100(),
+    reason="MXFP8 requires CUDA capability 10.0 or greater",
+)
+@pytest.mark.skipif(
+    "mxfp8_cuda" not in globals(),
+    reason="mxfp8_cuda extension not available",
+)
+@pytest.mark.parametrize("M", (32, 64, 2048))
+@pytest.mark.parametrize("K", (32, 64, 2048))
+@pytest.mark.parametrize("input_dtype", (torch.float32, torch.bfloat16))
+@pytest.mark.parametrize(
+    "scaling_mode", (ScaleCalculationMode.FLOOR, ScaleCalculationMode.RCEIL)
+)
+def test_cuda_mx_dim1_numerics(M, K, input_dtype, scaling_mode):
+    scaling_mode_str = (
+        "floor" if scaling_mode == ScaleCalculationMode.FLOOR else "rceil"
+    )
+    block_size = 32
+
+    # Use distinct incrementing values from 0 to M*K-1 to make debugging easier.
+    x = (
+        torch.arange(0, M * K, dtype=input_dtype, device="cuda")
+        .reshape(M, K)
+        .contiguous()
+    )
+
+    y_d1_ref, s_d1_ref = to_mx_dim1_reference(
+        x,
+        block_size=block_size,
+        scaling_mode=scaling_mode,
+    )
+
+    _, y_d1, _, s_d1 = mxfp8_cuda.quantize(
+        x,
+        rowwise=False,
+        colwise=True,
+        scaling_mode=scaling_mode_str,
+        scale_dim_x=1,
+        scale_dim_y=block_size,
+    )
+
+    # check scales
+    torch.testing.assert_close(s_d1, s_d1_ref, rtol=0, atol=0)
+
+    # check quantized values
+    torch.testing.assert_close(y_d1, y_d1_ref, rtol=0, atol=0)
+    assert y_d1.stride() == y_d1_ref.stride(), "quantized tensor strides do not match"
+
+
+@pytest.mark.skipif(
+    not is_sm_at_least_100(),
+    reason="MXFP8 requires CUDA capability 10.0 or greater",
+)
+@pytest.mark.skipif(
+    "mxfp8_cuda" not in globals(),
+    reason="mxfp8_cuda extension not available",
+)
+def test_cuda_mx_dim0_not_supported():
+    M, K = 64, 64
+    block_size = 32
+    x = (
+        torch.arange(0, M * K, dtype=torch.bfloat16, device="cuda")
+        .reshape(M, K)
+        .contiguous()
+    )
+    with pytest.raises(RuntimeError):
+        _, y_d1, _, s_d1 = mxfp8_cuda.quantize(
+            x,
+            rowwise=True,
+            colwise=False,
+            scale_dim_x=block_size,
+            scale_dim_y=1,
+        )
+
+
+@pytest.mark.skipif(
+    not is_sm_at_least_100(),
+    reason="MXFP8 requires CUDA capability 10.0 or greater",
+)
+@pytest.mark.skipif(
+    "mxfp8_cuda" not in globals(),
+    reason="mxfp8_cuda extension not available",
+)
+def test_cuda_mx_dim1_invalid_block_size():
+    M, K = 64, 64
+    x = (
+        torch.arange(0, M * K, dtype=torch.bfloat16, device="cuda")
+        .reshape(M, K)
+        .contiguous()
+    )
+    invalid_block_size = 4
+    with pytest.raises(RuntimeError):
+        _, y_d1, _, s_d1 = mxfp8_cuda.quantize(
+            x,
+            rowwise=False,
+            colwise=True,
+            scale_dim_x=1,
+            scale_dim_y=invalid_block_size,
+        )

torchao/prototype/mx_formats/kernels.py

Lines changed: 8 additions & 4 deletions

@@ -1375,17 +1375,21 @@ def custom_triton_to_mxfp8_dim1_sharding(x, inner_block_size=32):
         return acceptable_shardings
 
     def triton_to_mxfp8_dim1_reference(
-        x_hp: torch.Tensor, block_size
+        x_hp: torch.Tensor,
+        block_size,
+        scaling_mode="FLOOR",
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         A reference version of `to_mxfp8_dim1`.
         """
-        from torchao.prototype.mx_formats.mx_tensor import to_mx
+        from torchao.prototype.mx_formats.mx_tensor import ScaleCalculationMode, to_mx
+
+        scale_mode = ScaleCalculationMode[scaling_mode]
 
         # cast across dim1
         x_hp_d1 = x_hp.t().contiguous()
         scale_e8m0_dim1, x_hp_d1_normalized = to_mx(
-            x_hp_d1, torch.float8_e4m3fn, block_size
+            x_hp_d1, torch.float8_e4m3fn, block_size, scaling_mode=scale_mode
        )
         scale_e8m0_dim1 = scale_e8m0_dim1.view(torch.float8_e8m0fnu)
         return (
@@ -1718,7 +1722,7 @@ def triton_to_mxfp8_dim1(
         raise AssertionError("needs torch version 2.8+ and triton")
 
     def triton_to_mxfp8_dim1_reference(
-        x_hp: torch.Tensor, block_size
+        x_hp: torch.Tensor, block_size, scaling_mode
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         raise AssertionError("needs torch version 2.8+ and triton")
