
Commit 97c77a7

Add CUDA kernel for MXFP8 dim1 casting
Co-authored-by: Less Wright <lessw@etrillium.com>
stack-info: PR: #2513, branch: danielvegamyhre/stack/3
1 parent ddd4021 commit 97c77a7

8 files changed: +1804 -20 lines changed
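The commit's new kernel is exposed as a colwise (dim1) MXFP8 cast through torchao.prototype.mxfp8_cuda.quantize, which the benchmark and test changes below exercise. As a rough orientation, here is a minimal sketch of the call as it appears in those diffs, assuming the extension was built (see the setup.py changes; it is gated to sm100a builds) and noting that the two discarded return values are presumably the rowwise outputs:

import torch
from torchao.prototype import mxfp8_cuda

x = torch.randn(2048, 2048, dtype=torch.bfloat16, device="cuda")

# Cast along dim1 (columns) to MXFP8 with floor-based scale calculation.
# The call returns four tensors; only the colwise data and colwise scales
# are used here, matching the benchmark code below.
_, y_d1, _, s_d1 = mxfp8_cuda.quantize(
    x, rowwise=False, colwise=True, scaling_mode="floor"
)
assert y_d1.dtype == torch.float8_e4m3fn   # fp8 e4m3 data
assert s_d1.dtype == torch.float8_e8m0fnu  # e8m0 block scales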

benchmarks/mx_formats/cast_bench.py

Lines changed: 73 additions & 15 deletions
@@ -4,12 +4,12 @@
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.

-from typing import Callable, Tuple
+from typing import Tuple

 import fire
 import torch
 import triton
-from torch._inductor.utils import do_bench_using_profiling
+from triton.testing import do_bench

 from torchao.prototype.mx_formats.kernels import (
     triton_to_mxfp8_dim1,
@@ -64,29 +64,35 @@ def to_mx_dim1_reference(x_hp, block_size):
     return data_d1.t(), scale_d1


-def benchmark_cuda_function_in_microseconds(func: Callable, *args, **kwargs) -> float:
-    """Thin wrapper around do_bench_using_profiling"""
-    no_args = lambda: func(*args, **kwargs)
-    time = do_bench_using_profiling(no_args)
-    return time * 1e3
+def benchmark_cuda_function_in_microseconds(f, *args):
+    return do_bench(lambda: f(*args), return_mode="median") * 1e3


 def run(
     M: int = 16384,
     K: int = 16384,
     BLOCK_SIZE: int = 32,
-    mode: str = "dim0",
+    mode: str = "dim0_floor",
 ):
     print(f"M {M} K {K} BLOCK_SIZE {BLOCK_SIZE}")
     print(f"GPU: {torch.cuda.get_device_name(0)}")
     print(f"torch version: {torch.__version__}")
     print(f"triton version: {triton.__version__}")
     print(f"mode: {mode}")
-    assert mode in ("dim0", "dim1", "dim0_dim1", "dim0_mx", "dim1_mx", "dim1_mx_triton")
+    assert mode in (
+        "dim0_floor",
+        "dim1_floor",
+        "dim0_dim1_floor",
+        "dim0_mx_floor",
+        "dim1_mx_floor",
+        "dim1_mx_triton_floor",
+        "dim1_mx_cuda_floor",
+        "dim1_mx_cuda_rceil",
+    )

     x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda") * 1000

-    if mode == "dim0":
+    if mode == "dim0_floor":
         scale_dim0_reference_c = torch.compile(scale_dim0_reference)
         y_d0, s_d0 = scale_dim0_reference_c(x, BLOCK_SIZE)

@@ -103,7 +109,7 @@ def run(
         bytes_rw = sum(t.numel() for t in [x, y_d0, s_d0]) * bytes_per_el_bf16
         bps = bytes_rw / (time_us / 1e6)

-    elif mode == "dim1":
+    elif mode == "dim1_floor":
         scale_dim1_reference_c = torch.compile(scale_dim1_reference)
         y_d1, s_d1 = scale_dim1_reference_c(x, BLOCK_SIZE)

@@ -120,7 +126,7 @@ def run(
         bytes_rw = sum(t.numel() for t in [x, y_d1, s_d1]) * bytes_per_el_bf16
         bps = bytes_rw / (time_us / 1e6)

-    elif mode == "dim0_dim1":
+    elif mode == "dim0_dim1_floor":
         scale_dim0_dim1_reference_c = torch.compile(scale_dim0_dim1_reference)
         y_d0, y_d1, s_d0, s_d1 = scale_dim0_dim1_reference_c(x, BLOCK_SIZE)

@@ -141,7 +147,7 @@ def run(
         )
         bps = bytes_rw / (time_us / 1e6)

-    elif mode == "dim0_mx":
+    elif mode == "dim0_mx_floor":
         to_mx_dim0_reference_c = torch.compile(to_mx_dim0_reference)
         y_d0, s_d0 = to_mx_dim0_reference_c(x, BLOCK_SIZE)

@@ -159,7 +165,7 @@ def run(
         bytes_w = (y_d0.numel() + s_d0.numel()) * bytes_per_el_fp8
         bps = (bytes_r + bytes_w) / (time_us / 1e6)

-    elif mode == "dim1_mx":
+    elif mode == "dim1_mx_floor":
         to_mx_dim1_reference_c = torch.compile(to_mx_dim1_reference)
         y_d1, s_d1 = to_mx_dim1_reference_c(x, BLOCK_SIZE)

@@ -177,7 +183,7 @@ def run(
         bytes_w = (y_d1.numel() + s_d1.numel()) * bytes_per_el_fp8
         bps = (bytes_r + bytes_w) / (time_us / 1e6)

-    elif mode == "dim1_mx_triton":
+    elif mode == "dim1_mx_triton_floor":
         y_d1, s_d1 = triton_to_mxfp8_dim1(x, inner_block_size=BLOCK_SIZE)

         for _ in range(2):
@@ -194,6 +200,58 @@ def run(
         bytes_w = (y_d1.numel() + s_d1.numel()) * bytes_per_el_fp8
         bps = (bytes_r + bytes_w) / (time_us / 1e6)

+    elif mode == "dim1_mx_cuda_floor":
+        from torchao.prototype import mxfp8_cuda
+
+        _, y_d1, _, s_d1 = mxfp8_cuda.quantize(
+            x, rowwise=False, colwise=True, scaling_mode="floor"
+        )
+
+        for _ in range(2):
+            __ = mxfp8_cuda.quantize(
+                x, rowwise=False, colwise=True, scaling_mode="floor"
+            )
+
+        time_us = benchmark_cuda_function_in_microseconds(
+            lambda x: mxfp8_cuda.quantize(
+                x, rowwise=False, colwise=True, scaling_mode="floor"
+            ),
+            x,
+        )
+
+        assert y_d1.dtype == torch.float8_e4m3fn
+        assert s_d1.dtype == torch.float8_e8m0fnu
+
+        bytes_r = x.numel() * bytes_per_el_bf16
+        bytes_w = (y_d1.numel() + s_d1.numel()) * bytes_per_el_fp8
+        bps = (bytes_r + bytes_w) / (time_us / 1e6)
+
+    elif mode == "dim1_mx_cuda_rceil":
+        from torchao.prototype import mxfp8_cuda
+
+        _, y_d1, _, s_d1 = mxfp8_cuda.quantize(
+            x, rowwise=False, colwise=True, scaling_mode="rceil"
+        )
+
+        for _ in range(2):
+            __ = mxfp8_cuda.quantize(
+                x, rowwise=False, colwise=True, scaling_mode="rceil"
+            )
+
+        time_us = benchmark_cuda_function_in_microseconds(
+            lambda x: mxfp8_cuda.quantize(
+                x, rowwise=False, colwise=True, scaling_mode="rceil"
+            ),
+            x,
+        )
+
+        assert y_d1.dtype == torch.float8_e4m3fn
+        assert s_d1.dtype == torch.float8_e8m0fnu
+
+        bytes_r = x.numel() * bytes_per_el_bf16
+        bytes_w = (y_d1.numel() + s_d1.numel()) * bytes_per_el_fp8
+        bps = (bytes_r + bytes_w) / (time_us / 1e6)
+
     else:
         raise AssertionError(f"unknown mode {mode}")

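For reference, a minimal sketch of driving the two new benchmark modes from Python, assuming run() is reachable as a module import (the script's fire-based CLI entry point sits outside this hunk, and the import path shown is hypothetical); the argument names come from run() above:

# Hypothetical direct invocation of the benchmark's run() function; adjust the
# import path to however benchmarks/mx_formats/cast_bench.py is exposed in your checkout.
from cast_bench import run

run(M=16384, K=16384, BLOCK_SIZE=32, mode="dim1_mx_cuda_floor")
run(M=16384, K=16384, BLOCK_SIZE=32, mode="dim1_mx_cuda_rceil")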
setup.py

Lines changed: 38 additions & 0 deletions
@@ -490,6 +490,14 @@ def get_extensions():
     if use_cuda:
         sources += cuda_sources

+    # Add MXFP8 cuda extension dir
+    mxfp8_extension_dir = os.path.join(extensions_dir, "cuda", "mx_kernels")
+    mxfp8_sources_to_exclude = list(
+        glob.glob(os.path.join(mxfp8_extension_dir, "**/*"), recursive=True)
+    )
+    sources = [s for s in sources if s not in mxfp8_sources_to_exclude]
+    print("sources after mxfp8 exclusion", sources)
+
     # TOOD: Remove this and use what CUDA has once we fix all the builds.
     if use_rocm:
         # Add ROCm GPU architecture check
@@ -610,6 +618,36 @@ def get_extensions():
             )
         )

+    # Add the mxfp8 casting CUDA extension
+    if use_cuda:
+        mxfp8_sources = [
+            os.path.join(mxfp8_extension_dir, "mxfp8_extension.cpp"),
+            os.path.join(mxfp8_extension_dir, "mxfp8_cuda.cu"),
+        ]
+
+        # Only add the extension if the source files exist AND we are building for sm100
+        mxfp8_src_files_exist = all(os.path.exists(f) for f in mxfp8_sources)
+        if mxfp8_src_files_exist and build_for_sm100a:
+            print("Building mxfp8_cuda extension")
+            ext_modules.append(
+                CUDAExtension(
+                    name="torchao.prototype.mxfp8_cuda",
+                    sources=mxfp8_sources,
+                    include_dirs=[
+                        mxfp8_extension_dir,  # For mxfp8_quantize.cuh, mxfp8_extension.cpp, and mxfp8_cuda.cu
+                        "/usr/local/cuda-12.8/include",  # CUDA 12.8 headers
+                    ],
+                    library_dirs=[
+                        "/usr/local/cuda-12.8/lib64",  # CUDA 12.8 libraries
+                    ],
+                    extra_compile_args={
+                        "cxx": ["-std=c++17", "-O3"],
+                        "nvcc": nvcc_args,
+                    },
+                    extra_link_args=["-lcuda", "-lcudart"],
+                ),
+            )
+
     # Only build the cutlass_90a extension if sm90a is in the architecture flags
     if (
         cutlass_90a_sources is not None

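Because the extension is only compiled when its sources exist and build_for_sm100a is set, the resulting torchao.prototype.mxfp8_cuda module will be absent in many installs. A minimal availability probe, as a sketch (the helper name is hypothetical, not part of this commit):

def mxfp8_cuda_available() -> bool:
    # True only if the optional sm100a-only extension was built and can be imported.
    try:
        from torchao.prototype import mxfp8_cuda  # noqa: F401
        return True
    except ImportError:
        return False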
test/prototype/mx_formats/test_kernels.py

Lines changed: 106 additions & 1 deletion
@@ -42,7 +42,7 @@
     triton_to_mxfp8_dim1_reference,
     unpack_uint4,
 )
-from torchao.prototype.mx_formats.mx_tensor import MXTensor
+from torchao.prototype.mx_formats.mx_tensor import MXTensor, ScaleCalculationMode, to_mx
 from torchao.prototype.mx_formats.utils import to_blocked
 from torchao.utils import (
     TORCH_VERSION_AT_LEAST_2_8,
@@ -56,6 +56,15 @@
     pytest.skip("Unsupported PyTorch version", allow_module_level=True)


+# TODO: shared utils file for benchmarking and testing
+def to_mx_dim1_reference(x_hp, block_size, scaling_mode):
+    x_hp = x_hp.t().contiguous()
+    scale_d1, data_d1 = to_mx(
+        x_hp, torch.float8_e4m3fn, block_size, scaling_mode=scaling_mode
+    )
+    return data_d1.t(), scale_d1
+
+
 @pytest.mark.skip(
     reason="TODO debug CI failure, low pri since this is not used in the MX code"  # noqa: E501
 )
@@ -488,3 +497,99 @@ def test_rearrange(shape):
     eager = to_blocked(scales, False)
     triton = to_blocked(scales, True)
     torch.testing.assert_close(eager, triton, atol=0, rtol=0)
+
+
+@pytest.mark.skipif(
+    not is_sm_at_least_100(),
+    reason="MXFP8 requires CUDA capability 10.0 or greater",
+)
+@pytest.mark.parametrize("M", (32, 64, 2048))
+@pytest.mark.parametrize("K", (32, 64, 2048))
+@pytest.mark.parametrize("input_dtype", (torch.float32, torch.bfloat16))
+@pytest.mark.parametrize(
+    "scaling_mode", (ScaleCalculationMode.FLOOR, ScaleCalculationMode.RCEIL)
+)
+def test_cuda_mx_dim1_numerics(M, K, input_dtype, scaling_mode):
+    from torchao.prototype import mxfp8_cuda
+
+    scaling_mode_str = (
+        "floor" if scaling_mode == ScaleCalculationMode.FLOOR else "rceil"
+    )
+    block_size = 32
+
+    # Use disinct incrementing values from 0 to M*K-1 to make debugging easier.
+    x = (
+        torch.arange(0, M * K, dtype=input_dtype, device="cuda")
+        .reshape(M, K)
+        .contiguous()
+    )
+
+    y_d1_ref, s_d1_ref = to_mx_dim1_reference(
+        x,
+        block_size=block_size,
+        scaling_mode=scaling_mode,
+    )
+
+    _, y_d1, _, s_d1 = mxfp8_cuda.quantize(
+        x,
+        rowwise=False,
+        colwise=True,
+        scaling_mode=scaling_mode_str,
+        scale_dim_x=1,
+        scale_dim_y=block_size,
+    )
+
+    # check scales
+    torch.testing.assert_close(s_d1, s_d1_ref, rtol=0, atol=0)
+
+    # check quantized values
+    torch.testing.assert_close(y_d1, y_d1_ref, rtol=0, atol=0)
+    assert y_d1.stride() == y_d1_ref.stride(), "quantized tensor strides do not match"
+
+
+@pytest.mark.skipif(
+    not is_sm_at_least_100(),
+    reason="MXFP8 requires CUDA capability 10.0 or greater",
+)
+def test_cuda_mx_dim0_not_supported():
+    from torchao.prototype import mxfp8_cuda
+
+    M, K = 64, 64
+    block_size = 32
+    x = (
+        torch.arange(0, M * K, dtype=torch.bfloat16, device="cuda")
+        .reshape(M, K)
+        .contiguous()
+    )
+    with pytest.raises(RuntimeError):
+        _, y_d1, _, s_d1 = mxfp8_cuda.quantize(
+            x,
+            rowwise=True,
+            colwise=False,
+            scale_dim_x=block_size,
+            scale_dim_y=1,
+        )
+
+
+@pytest.mark.skipif(
+    not is_sm_at_least_100(),
+    reason="MXFP8 requires CUDA capability 10.0 or greater",
+)
+def test_cuda_mx_dim1_invalid_block_size():
+    from torchao.prototype import mxfp8_cuda
+
+    M, K = 64, 64
+    x = (
+        torch.arange(0, M * K, dtype=torch.bfloat16, device="cuda")
+        .reshape(M, K)
+        .contiguous()
+    )
+    invalid_block_size = 4
+    with pytest.raises(RuntimeError):
+        _, y_d1, _, s_d1 = mxfp8_cuda.quantize(
+            x,
+            rowwise=False,
+            colwise=True,
+            scale_dim_x=1,
+            scale_dim_y=invalid_block_size,
+        )

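Outside of pytest, the dim1 numerics check above boils down to comparing the CUDA kernel against a to_mx reference computed on the transposed input. A minimal standalone sketch of that comparison (FLOOR mode, block_size=32), using only the calls that appear in the test diff:

import torch
from torchao.prototype import mxfp8_cuda
from torchao.prototype.mx_formats.mx_tensor import ScaleCalculationMode, to_mx

M, K, block_size = 64, 64, 32
x = torch.arange(0, M * K, dtype=torch.bfloat16, device="cuda").reshape(M, K).contiguous()

# Reference: transpose, cast rowwise with to_mx, transpose the data back
# (same recipe as the to_mx_dim1_reference helper above).
s_ref, y_ref = to_mx(
    x.t().contiguous(), torch.float8_e4m3fn, block_size,
    scaling_mode=ScaleCalculationMode.FLOOR,
)
y_ref = y_ref.t()

# CUDA kernel: colwise (dim1) cast with floor scaling.
_, y_d1, _, s_d1 = mxfp8_cuda.quantize(
    x, rowwise=False, colwise=True, scaling_mode="floor",
    scale_dim_x=1, scale_dim_y=block_size,
)

torch.testing.assert_close(y_d1, y_ref, rtol=0, atol=0)
torch.testing.assert_close(s_d1, s_ref, rtol=0, atol=0)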