Add CUDA kernel for MXFP8 dim1 casting #2513

Open · wants to merge 1 commit into base: main

88 changes: 73 additions & 15 deletions benchmarks/mx_formats/cast_bench.py
@@ -4,12 +4,12 @@
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

from typing import Callable, Tuple
from typing import Tuple

import fire
import torch
import triton
from torch._inductor.utils import do_bench_using_profiling
from triton.testing import do_bench

from torchao.prototype.mx_formats.kernels import (
triton_to_mxfp8_dim1,
@@ -64,29 +64,35 @@ def to_mx_dim1_reference(x_hp, block_size):
return data_d1.t(), scale_d1


def benchmark_cuda_function_in_microseconds(func: Callable, *args, **kwargs) -> float:
"""Thin wrapper around do_bench_using_profiling"""
no_args = lambda: func(*args, **kwargs)
time = do_bench_using_profiling(no_args)
return time * 1e3
def benchmark_cuda_function_in_microseconds(f, *args):
return do_bench(lambda: f(*args), return_mode="median") * 1e3


def run(
M: int = 16384,
K: int = 16384,
BLOCK_SIZE: int = 32,
mode: str = "dim0",
mode: str = "dim0_floor",
):
print(f"M {M} K {K} BLOCK_SIZE {BLOCK_SIZE}")
print(f"GPU: {torch.cuda.get_device_name(0)}")
print(f"torch version: {torch.__version__}")
print(f"triton version: {triton.__version__}")
print(f"mode: {mode}")
assert mode in ("dim0", "dim1", "dim0_dim1", "dim0_mx", "dim1_mx", "dim1_mx_triton")
assert mode in (
"dim0_floor",
"dim1_floor",
"dim0_dim1_floor",
"dim0_mx_floor",
"dim1_mx_floor",
"dim1_mx_triton_floor",
"dim1_mx_cuda_floor",
Contributor:
nit: remove "floor" to match all the others, or add "floor" to all the others

Contributor (Author):
Added floor to others, I like the more explicit naming.
"dim1_mx_cuda_rceil",
)

x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda") * 1000

if mode == "dim0":
if mode == "dim0_floor":
scale_dim0_reference_c = torch.compile(scale_dim0_reference)
y_d0, s_d0 = scale_dim0_reference_c(x, BLOCK_SIZE)

@@ -103,7 +109,7 @@ def run(
bytes_rw = sum(t.numel() for t in [x, y_d0, s_d0]) * bytes_per_el_bf16
bps = bytes_rw / (time_us / 1e6)

elif mode == "dim1":
elif mode == "dim1_floor":
scale_dim1_reference_c = torch.compile(scale_dim1_reference)
y_d1, s_d1 = scale_dim1_reference_c(x, BLOCK_SIZE)

@@ -120,7 +126,7 @@ def run(
bytes_rw = sum(t.numel() for t in [x, y_d1, s_d1]) * bytes_per_el_bf16
bps = bytes_rw / (time_us / 1e6)

elif mode == "dim0_dim1":
elif mode == "dim0_dim1_floor":
scale_dim0_dim1_reference_c = torch.compile(scale_dim0_dim1_reference)
y_d0, y_d1, s_d0, s_d1 = scale_dim0_dim1_reference_c(x, BLOCK_SIZE)

@@ -141,7 +147,7 @@ def run(
)
bps = bytes_rw / (time_us / 1e6)

elif mode == "dim0_mx":
elif mode == "dim0_mx_floor":
to_mx_dim0_reference_c = torch.compile(to_mx_dim0_reference)
y_d0, s_d0 = to_mx_dim0_reference_c(x, BLOCK_SIZE)

@@ -159,7 +165,7 @@ def run(
bytes_w = (y_d0.numel() + s_d0.numel()) * bytes_per_el_fp8
bps = (bytes_r + bytes_w) / (time_us / 1e6)

elif mode == "dim1_mx":
elif mode == "dim1_mx_floor":
to_mx_dim1_reference_c = torch.compile(to_mx_dim1_reference)
y_d1, s_d1 = to_mx_dim1_reference_c(x, BLOCK_SIZE)

@@ -177,7 +183,7 @@ def run(
bytes_w = (y_d1.numel() + s_d1.numel()) * bytes_per_el_fp8
bps = (bytes_r + bytes_w) / (time_us / 1e6)

elif mode == "dim1_mx_triton":
elif mode == "dim1_mx_triton_floor":
y_d1, s_d1 = triton_to_mxfp8_dim1(x, inner_block_size=BLOCK_SIZE)

for _ in range(2):
@@ -194,6 +200,58 @@ def run(
bytes_w = (y_d1.numel() + s_d1.numel()) * bytes_per_el_fp8
bps = (bytes_r + bytes_w) / (time_us / 1e6)

elif mode == "dim1_mx_cuda_floor":
from torchao.prototype import mxfp8_cuda

_, y_d1, _, s_d1 = mxfp8_cuda.quantize(
x, rowwise=False, colwise=True, scaling_mode="floor"
)

for _ in range(2):
__ = mxfp8_cuda.quantize(
x, rowwise=False, colwise=True, scaling_mode="floor"
)

time_us = benchmark_cuda_function_in_microseconds(
lambda x: mxfp8_cuda.quantize(
x, rowwise=False, colwise=True, scaling_mode="floor"
),
x,
)

assert y_d1.dtype == torch.float8_e4m3fn
assert s_d1.dtype == torch.float8_e8m0fnu

bytes_r = x.numel() * bytes_per_el_bf16
bytes_w = (y_d1.numel() + s_d1.numel()) * bytes_per_el_fp8
bps = (bytes_r + bytes_w) / (time_us / 1e6)

elif mode == "dim1_mx_cuda_rceil":
from torchao.prototype import mxfp8_cuda

_, y_d1, _, s_d1 = mxfp8_cuda.quantize(
x, rowwise=False, colwise=True, scaling_mode="rceil"
)

for _ in range(2):
__ = mxfp8_cuda.quantize(
x, rowwise=False, colwise=True, scaling_mode="rceil"
)

time_us = benchmark_cuda_function_in_microseconds(
lambda x: mxfp8_cuda.quantize(
x, rowwise=False, colwise=True, scaling_mode="rceil"
),
x,
)

assert y_d1.dtype == torch.float8_e4m3fn
assert s_d1.dtype == torch.float8_e8m0fnu

bytes_r = x.numel() * bytes_per_el_bf16
bytes_w = (y_d1.numel() + s_d1.numel()) * bytes_per_el_fp8
bps = (bytes_r + bytes_w) / (time_us / 1e6)

else:
raise AssertionError(f"unknown mode {mode}")

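For context, a minimal sketch of how the new dim1_mx_cuda_* benchmark modes discussed in the review thread above could be run. It assumes the usual fire.Fire(run) entry point at the bottom of cast_bench.py (not visible in this diff) and that the benchmarks directory is importable, so treat it as an illustration rather than part of the PR:

# Illustrative sketch (not part of this PR): drive the new CUDA benchmark modes.
# Assumes `benchmarks/` is on PYTHONPATH, an sm100-capable GPU, and a build of
# the torchao.prototype.mxfp8_cuda extension.
from benchmarks.mx_formats.cast_bench import run

# Floor scale calculation, matching the existing reference/triton modes.
run(M=16384, K=16384, BLOCK_SIZE=32, mode="dim1_mx_cuda_floor")

# RCEIL scale calculation, the alternative mode exposed by the new kernel.
run(M=16384, K=16384, BLOCK_SIZE=32, mode="dim1_mx_cuda_rceil")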
38 changes: 38 additions & 0 deletions setup.py
@@ -490,6 +490,14 @@ def get_extensions():
if use_cuda:
sources += cuda_sources

# Add MXFP8 cuda extension dir
mxfp8_extension_dir = os.path.join(extensions_dir, "cuda", "mx_kernels")
mxfp8_sources_to_exclude = list(
glob.glob(os.path.join(mxfp8_extension_dir, "**/*"), recursive=True)
)
sources = [s for s in sources if s not in mxfp8_sources_to_exclude]
print("sources after mxfp8 exclusion", sources)

# TODO: Remove this and use what CUDA has once we fix all the builds.
if use_rocm:
# Add ROCm GPU architecture check
@@ -610,6 +618,36 @@ def get_extensions():
)
)

# Add the mxfp8 casting CUDA extension
if use_cuda:
mxfp8_sources = [
os.path.join(mxfp8_extension_dir, "mxfp8_extension.cpp"),
os.path.join(mxfp8_extension_dir, "mxfp8_cuda.cu"),
]

# Only add the extension if the source files exist AND we are building for sm100
mxfp8_src_files_exist = all(os.path.exists(f) for f in mxfp8_sources)
if mxfp8_src_files_exist and build_for_sm100a:
print("Building mxfp8_cuda extension")
ext_modules.append(
CUDAExtension(
name="torchao.prototype.mxfp8_cuda",
sources=mxfp8_sources,
include_dirs=[
mxfp8_extension_dir, # For mxfp8_quantize.cuh, mxfp8_extension.cpp, and mxfp8_cuda.cu
"/usr/local/cuda-12.8/include", # CUDA 12.8 headers
],
library_dirs=[
"/usr/local/cuda-12.8/lib64", # CUDA 12.8 libraries
],
extra_compile_args={
"cxx": ["-std=c++17", "-O3"],
"nvcc": nvcc_args,
},
extra_link_args=["-lcuda", "-lcudart"],
),
)

# Only build the cutlass_90a extension if sm90a is in the architecture flags
if (
cutlass_90a_sources is not None
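As a quick sanity check of the build change above, something like the sketch below could be used; the capability gate mirrors the sm100a condition in setup.py, and the whole snippet is an illustration rather than code from this PR:

# Sketch: confirm the mxfp8_cuda extension added by this setup.py change built
# and is importable on a supported GPU. Only the quantize() entry point shown
# elsewhere in this PR is assumed to exist.
import torch

if torch.cuda.is_available() and torch.cuda.get_device_capability() >= (10, 0):
    from torchao.prototype import mxfp8_cuda

    x = torch.randn(64, 64, dtype=torch.bfloat16, device="cuda")
    _, data_colwise, _, scales_colwise = mxfp8_cuda.quantize(
        x, rowwise=False, colwise=True, scaling_mode="floor"
    )
    # Expect fp8 e4m3 data and e8m0 scales, matching the asserts in the benchmark.
    print(data_colwise.dtype, scales_colwise.dtype)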
107 changes: 106 additions & 1 deletion test/prototype/mx_formats/test_kernels.py
@@ -42,7 +42,7 @@
triton_to_mxfp8_dim1_reference,
unpack_uint4,
)
from torchao.prototype.mx_formats.mx_tensor import MXTensor
from torchao.prototype.mx_formats.mx_tensor import MXTensor, ScaleCalculationMode, to_mx
from torchao.prototype.mx_formats.utils import to_blocked
from torchao.utils import (
TORCH_VERSION_AT_LEAST_2_8,
@@ -56,6 +56,15 @@
pytest.skip("Unsupported PyTorch version", allow_module_level=True)


# TODO: shared utils file for benchmarking and testing
def to_mx_dim1_reference(x_hp, block_size, scaling_mode):
x_hp = x_hp.t().contiguous()
scale_d1, data_d1 = to_mx(
x_hp, torch.float8_e4m3fn, block_size, scaling_mode=scaling_mode
)
return data_d1.t(), scale_d1


@pytest.mark.skip(
reason="TODO debug CI failure, low pri since this is not used in the MX code" # noqa: E501
)
@@ -488,3 +497,99 @@ def test_rearrange(shape):
eager = to_blocked(scales, False)
triton = to_blocked(scales, True)
torch.testing.assert_close(eager, triton, atol=0, rtol=0)


@pytest.mark.skipif(
not is_sm_at_least_100(),
reason="MXFP8 requires CUDA capability 10.0 or greater",
)
@pytest.mark.parametrize("M", (32, 64, 2048))
@pytest.mark.parametrize("K", (32, 64, 2048))
@pytest.mark.parametrize("input_dtype", (torch.float32, torch.bfloat16))
@pytest.mark.parametrize(
"scaling_mode", (ScaleCalculationMode.FLOOR, ScaleCalculationMode.RCEIL)
)
def test_cuda_mx_dim1_numerics(M, K, input_dtype, scaling_mode):
from torchao.prototype import mxfp8_cuda

scaling_mode_str = (
"floor" if scaling_mode == ScaleCalculationMode.FLOOR else "rceil"
)
block_size = 32

# Use distinct incrementing values from 0 to M*K-1 to make debugging easier.
x = (
torch.arange(0, M * K, dtype=input_dtype, device="cuda")
.reshape(M, K)
.contiguous()
)

y_d1_ref, s_d1_ref = to_mx_dim1_reference(
x,
block_size=block_size,
scaling_mode=scaling_mode,
)

_, y_d1, _, s_d1 = mxfp8_cuda.quantize(
x,
rowwise=False,
colwise=True,
scaling_mode=scaling_mode_str,
scale_dim_x=1,
scale_dim_y=block_size,
)

# check scales
torch.testing.assert_close(s_d1, s_d1_ref, rtol=0, atol=0)

# check quantized values
torch.testing.assert_close(y_d1, y_d1_ref, rtol=0, atol=0)
Contributor:
I think we should also test the memory layout of all the tensors vs reference
assert y_d1.stride() == y_d1_ref.stride(), "quantized tensor strides do not match"


@pytest.mark.skipif(
not is_sm_at_least_100(),
reason="MXFP8 requires CUDA capability 10.0 or greater",
)
def test_cuda_mx_dim0_not_supported():
from torchao.prototype import mxfp8_cuda

M, K = 64, 64
block_size = 32
x = (
torch.arange(0, M * K, dtype=torch.bfloat16, device="cuda")
.reshape(M, K)
.contiguous()
)
with pytest.raises(RuntimeError):
_, y_d1, _, s_d1 = mxfp8_cuda.quantize(
x,
rowwise=True,
colwise=False,
scale_dim_x=block_size,
scale_dim_y=1,
)


@pytest.mark.skipif(
not is_sm_at_least_100(),
reason="MXFP8 requires CUDA capability 10.0 or greater",
)
def test_cuda_mx_dim1_invalid_block_size():
from torchao.prototype import mxfp8_cuda

M, K = 64, 64
x = (
torch.arange(0, M * K, dtype=torch.bfloat16, device="cuda")
.reshape(M, K)
.contiguous()
)
invalid_block_size = 4
with pytest.raises(RuntimeError):
_, y_d1, _, s_d1 = mxfp8_cuda.quantize(
x,
rowwise=False,
colwise=True,
scale_dim_x=1,
scale_dim_y=invalid_block_size,
)
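
Following up on the memory-layout review comment in test_cuda_mx_dim1_numerics, a small helper along the lines of the sketch below could extend the existing stride assert to the scale tensors as well; the helper name and its use on the scales are assumptions, not part of this PR:

import torch

# Hypothetical layout-check helper the dim1 numerics test could reuse for both
# the quantized data and the scales (sketch only, not in this PR).
def assert_same_layout(actual: torch.Tensor, expected: torch.Tensor, name: str) -> None:
    assert actual.shape == expected.shape, f"{name}: shape {actual.shape} vs {expected.shape}"
    assert actual.stride() == expected.stride(), f"{name}: stride {actual.stride()} vs {expected.stride()}"

# Inside test_cuda_mx_dim1_numerics this would be called as:
#   assert_same_layout(y_d1, y_d1_ref, "quantized data (dim1)")
#   assert_same_layout(s_d1, s_d1_ref, "scales (dim1)")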