
Commit 23c3162

add mxfp8_cublas recipe to mx_formats (#1831)
* Update [ghstack-poisoned]
* Update [ghstack-poisoned]
* Update [ghstack-poisoned]
* Update [ghstack-poisoned]
* Update [ghstack-poisoned]
* Update [ghstack-poisoned]
1 parent 8641fd6 commit 23c3162
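To see the new recipe in context, here is a minimal sketch of how it is intended to be selected, using only APIs touched in this commit (`MXLinearConfig.from_recipe_name`, `swap_linear_with_mx_linear`). It assumes an NVIDIA Blackwell GPU, a recent PyTorch nightly with mxfp8 `torch._scaled_mm` support, and that bfloat16 is an accepted high-precision dtype for the swapped modules; treat it as illustrative, not authoritative.

```python
import torch

from torchao.prototype.mx_formats.config import MXLinearConfig, MXLinearRecipeName
from torchao.prototype.mx_formats.mx_linear import swap_linear_with_mx_linear

# build a config from the recipe added in this commit
config = MXLinearConfig.from_recipe_name(MXLinearRecipeName.MXFP8_CUBLAS)

# swap eligible nn.Linear modules for MX linear modules using that config
m = torch.nn.Sequential(torch.nn.Linear(32, 32)).cuda().to(torch.bfloat16)
swap_linear_with_mx_linear(m, config=config)

# the forward pass dispatches the mxfp8 gemm to cuBLAS via torch._scaled_mm
y = m(torch.randn(16, 32, device="cuda", dtype=torch.bfloat16))
```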

File tree

5 files changed: +60 -17 lines changed


test/prototype/mx_formats/test_mx_linear.py

Lines changed: 18 additions & 8 deletions
@@ -11,7 +11,10 @@
 import torch
 import torch.nn as nn
 
-from torchao.prototype.mx_formats.config import MXGemmKernelChoice, MXLinearConfig
+from torchao.prototype.mx_formats.config import (
+    MXLinearConfig,
+    MXLinearRecipeName,
+)
 from torchao.prototype.mx_formats.constants import DTYPE_FP4, SUPPORTED_ELEM_DTYPES
 from torchao.prototype.mx_formats.mx_linear import (
     MXInferenceLinear,
@@ -98,9 +101,16 @@ def test_linear_eager(elem_dtype, bias, input_shape):
 @pytest.mark.skipif(
     not is_sm_at_least_100(), reason="CUDA capability >= 10.0 required for mxfloat8"
 )
-@pytest.mark.parametrize("elem_dtype", [torch.float8_e4m3fn, DTYPE_FP4])
+@pytest.mark.parametrize(
+    "recipe_name",
+    [
+        MXLinearRecipeName.MXFP8_CUBLAS,
+        MXLinearRecipeName.MXFP8_CUTLASS,
+        MXLinearRecipeName.MXFP4_CUTLASS,
+    ],
+)
 @pytest.mark.parametrize("mkn", [(128, 256, 512), (256, 512, 128), (512, 128, 256)])
-def test_linear_eager_emulated_vs_real_gemm(elem_dtype, mkn):
+def test_linear_eager_emulated_vs_real_gemm(recipe_name, mkn):
     M, K, N = 128, 128, 128
     M, K, N = mkn
 
@@ -112,12 +122,12 @@ def test_linear_eager_emulated_vs_real_gemm(elem_dtype, mkn):
     )
     m_real = copy.deepcopy(m_emulated)
 
+    elem_dtype = torch.float8_e4m3fn
+    if recipe_name == MXLinearRecipeName.MXFP4_CUTLASS:
+        elem_dtype = DTYPE_FP4
+
     config_emulated = MXLinearConfig(block_size=32, elem_dtype=elem_dtype)
-    config_real = MXLinearConfig(
-        block_size=32,
-        elem_dtype=elem_dtype,
-        gemm_kernel_choice=MXGemmKernelChoice.CUTLASS,
-    )
+    config_real = MXLinearConfig.from_recipe_name(recipe_name)
 
     swap_linear_with_mx_linear(m_emulated, config=config_emulated)
     swap_linear_with_mx_linear(m_real, config=config_real)
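In words: for each recipe, the updated test builds an emulated copy and a real-gemm copy of the same model and checks that their outputs roughly agree. A rough standalone sketch of that flow is below; the model shape, bias setting, and the `allclose` tolerance are placeholder assumptions, since the comparison used by the actual test is not shown in this diff.

```python
import copy

import torch

from torchao.prototype.mx_formats.config import MXLinearConfig, MXLinearRecipeName
from torchao.prototype.mx_formats.constants import DTYPE_FP4
from torchao.prototype.mx_formats.mx_linear import swap_linear_with_mx_linear


def compare_emulated_vs_real(recipe_name, mkn=(128, 256, 512)):
    M, K, N = mkn
    m_emulated = torch.nn.Sequential(torch.nn.Linear(K, N, bias=False)).cuda().to(torch.bfloat16)
    m_real = copy.deepcopy(m_emulated)

    # mirror the elem_dtype selection added in this diff
    elem_dtype = torch.float8_e4m3fn
    if recipe_name == MXLinearRecipeName.MXFP4_CUTLASS:
        elem_dtype = DTYPE_FP4

    config_emulated = MXLinearConfig(block_size=32, elem_dtype=elem_dtype)
    config_real = MXLinearConfig.from_recipe_name(recipe_name)

    swap_linear_with_mx_linear(m_emulated, config=config_emulated)
    swap_linear_with_mx_linear(m_real, config=config_real)

    x = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
    y_emulated = m_emulated(x)
    y_real = m_real(x)

    # placeholder comparison; the actual test's check is not part of this diff
    assert torch.allclose(y_emulated, y_real, atol=1e-1, rtol=1e-1)
```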

torchao/prototype/mx_formats/README.md

Lines changed: 5 additions & 3 deletions
@@ -42,13 +42,15 @@ This is a module to do MX training, the MX matmul is currently emulated.
 ```python
 from torchao.prototype.mx_formats.mx_linear import swap_linear_with_mx_linear
 from torchao.prototype.mx_formats.config import MXLinearConfig, MXGemmKernelChoice
-from torchao.utils import is_sm_at_least_100
 
 # early prototype: on MX-enabled hardware, you can use the real MX gemm backed by
 # torchao's CUTLASS kernels. In the future, we will also add cuBLAS kernel support.
 gemm_kernel_choice = MXGemmKernelChoice.EMULATED
-if is_sm_at_least_100():
-    gemm_kernel_choice = MXGemmKernelChoice.CUTLASS
+
+# on NVIDIA Blackwell GPUs, you can also use cuBLAS or CUTLASS mxfp8 kernels
+# note: torch.compile support for both of these is WIP
+# gemm_kernel_choice = MXGemmKernelChoice.CUTLASS
+# gemm_kernel_choice = MXGemmKernelChoice.CUBLAS
 
 m = torch.nn.Sequential(torch.nn.Linear(32, 32)).cuda()
 config = MXLinearConfig(
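The README snippet above is cut off at the `config = MXLinearConfig(` line simply because that is where the hunk's context ends. A plausible continuation is sketched below for readability; the field values are assumptions consistent with the config fields that appear elsewhere in this commit, not the literal remainder of the README.

```python
# hypothetical continuation of the truncated README snippet above
config = MXLinearConfig(
    elem_dtype=torch.float8_e4m3fn,
    block_size=32,
    gemm_kernel_choice=gemm_kernel_choice,
)
swap_linear_with_mx_linear(m, config=config)
```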

torchao/prototype/mx_formats/config.py

Lines changed: 22 additions & 2 deletions
@@ -24,12 +24,16 @@ class MXGemmKernelChoice(Enum):
     # available only when CUDA capability is greater than or equal to 10.0
     CUTLASS = "cutlass"
 
-    # TODO(future PR): add cuBLAS here once we land pytorch/pytorch support
+    # available only when CUDA capability is greater than or equal to 10.0
+    # available on recent versions of PyTorch nightly, with https://github.com/pytorch/pytorch/pull/147548
+    # note: torch.compile does not work yet, see https://github.com/pytorch/pytorch/issues/147873
+    CUBLAS = "cublas"
 
 
 # Pre-made recipes for common configurations
 class MXLinearRecipeName(Enum):
     MXFP8_EMULATED = "mxfp8_emulated"
+    MXFP8_CUBLAS = "mxfp8_cublas"
     MXFP8_CUTLASS = "mxfp8_cutlass"
     MXFP4_EMULATED = "mxfp4_emulated"
     MXFP4_CUTLASS = "mxfp4_cutlass"
@@ -86,6 +90,20 @@ def __post_init__(self):
             assert (
                 self.elem_dtype_grad_output_override is None
             ), "elem_dtype_grad_output_override not supported for CUTLASS MX gemm kernels"
+        elif self.gemm_kernel_choice == MXGemmKernelChoice.CUBLAS:
+            assert (
+                self.block_size == 32
+            ), f"block_size must be 32 to use the cuBLAS MX gemm kernels, got {self.block_size}"
+            valid_dtypes = [torch.float8_e4m3fn]
+            assert (
+                self.elem_dtype in valid_dtypes
+            ), f"elem_dtype must be one of {valid_dtypes} to use the CUTLASS MX gemm kernels, got {self.elem_dtype}"
+            assert (
+                self.elem_dtype_weight_override is None
+            ), "elem_dtype_weight_override not supported for CUTLASS MX gemm kernels"
+            assert (
+                self.elem_dtype_grad_output_override is None
+            ), "elem_dtype_grad_output_override not supported for CUTLASS MX gemm kernels"
 
     @staticmethod
     def from_recipe_name(
@@ -104,11 +122,13 @@ def from_recipe_name(
 
         if recipe_name is MXLinearRecipeName.MXFP8_EMULATED:
             return MXLinearConfig()
+        elif recipe_name is MXLinearRecipeName.MXFP8_CUBLAS:
+            return MXLinearConfig(gemm_kernel_choice=MXGemmKernelChoice.CUBLAS)
         elif recipe_name is MXLinearRecipeName.MXFP8_CUTLASS:
             return MXLinearConfig(gemm_kernel_choice=MXGemmKernelChoice.CUTLASS)
         elif recipe_name is MXLinearRecipeName.MXFP4_EMULATED:
             return MXLinearConfig(elem_dtype=DTYPE_FP4)
-        elif recipe_name is MXLinearRecipeName.MXFP8_CUTLASS:
+        elif recipe_name is MXLinearRecipeName.MXFP4_CUTLASS:
             return MXLinearConfig(
                 elem_dtype=DTYPE_FP4, gemm_kernel_choice=MXGemmKernelChoice.CUTLASS
             )
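A short sketch of what the new plumbing provides, using only behavior visible in this diff: `from_recipe_name` maps the new enum member to a cuBLAS-backed config, and the new `__post_init__` branch rejects element dtypes the cuBLAS path does not accept. The `block_size`/`elem_dtype` asserts below rely on the config defaults satisfying the cuBLAS checks, which must hold for the recipe constructor to succeed.

```python
import torch

from torchao.prototype.mx_formats.config import (
    MXGemmKernelChoice,
    MXLinearConfig,
    MXLinearRecipeName,
)
from torchao.prototype.mx_formats.constants import DTYPE_FP4

# the new recipe resolves to a config with the cuBLAS kernel choice
config = MXLinearConfig.from_recipe_name(MXLinearRecipeName.MXFP8_CUBLAS)
assert config.gemm_kernel_choice == MXGemmKernelChoice.CUBLAS
# implied by the __post_init__ checks: cuBLAS requires block_size 32 and fp8 e4m3 elements
assert config.block_size == 32
assert config.elem_dtype == torch.float8_e4m3fn

# invalid combination: fp4 elements are not accepted by the cuBLAS path,
# so __post_init__ raises an AssertionError
try:
    MXLinearConfig(elem_dtype=DTYPE_FP4, gemm_kernel_choice=MXGemmKernelChoice.CUBLAS)
except AssertionError as e:
    print(e)
```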

torchao/prototype/mx_formats/mx_ops.py

Lines changed: 14 additions & 4 deletions
@@ -70,7 +70,7 @@ def mx_mm(aten_op, args, kwargs=None):
     b = args[1]
     assert isinstance(a, MXTensor) and isinstance(b, MXTensor)
     assert a._gemm_kernel_choice == b._gemm_kernel_choice, "unsupported"
-    if a._gemm_kernel_choice == MXGemmKernelChoice.CUTLASS:
+    if a._gemm_kernel_choice in (MXGemmKernelChoice.CUBLAS, MXGemmKernelChoice.CUTLASS):
         # real MX gemm backed by torchao's CUTLASS kernels
         M, K, N = a.shape[0], a.shape[1], b.shape[1]
         assert b._data.t().is_contiguous()
@@ -81,12 +81,22 @@
         b_scale_block = to_blocked(b_scale)
         if a._elem_dtype == torch.float8_e4m3fn:
             assert b._elem_dtype == torch.float8_e4m3fn
-            res = torchao.ops.mx_fp8_bf16(
-                a._data, b._data, a_scale_block, b_scale_block
-            )
+            if a._gemm_kernel_choice is MXGemmKernelChoice.CUBLAS:
+                res = torch._scaled_mm(
+                    a._data,
+                    b._data,
+                    a_scale_block.view(torch.float8_e8m0fnu),
+                    b_scale_block.view(torch.float8_e8m0fnu),
+                    out_dtype=torch.bfloat16,
+                )
+            else:
+                res = torchao.ops.mx_fp8_bf16(
+                    a._data, b._data, a_scale_block, b_scale_block
+                )
         else:
             assert a._elem_dtype == DTYPE_FP4
             assert b._elem_dtype == DTYPE_FP4
+            assert a._gemm_kernel_choice is MXGemmKernelChoice.CUTLASS, "unsupported"
             res = torchao.ops.mx_fp4_bf16(
                 a._data, b._data, a_scale_block, b_scale_block
             )
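One detail worth calling out in the new cuBLAS branch is the `.view(torch.float8_e8m0fnu)` on the block scales: the scales are stored as raw bytes, and viewing them as e8m0 reinterprets each byte as a power-of-two scale with bias 127. A tiny sketch of just that reinterpretation follows; it assumes a PyTorch build that exposes `torch.float8_e8m0fnu`, and the blocked scale layout produced by `to_blocked` is a separate concern not shown here.

```python
import torch

# e8m0 is an exponent-only 8-bit format: each byte is a biased exponent (bias 127),
# so a stored byte b represents the scale 2 ** (b - 127)
raw = torch.tensor([126, 127, 128, 130], dtype=torch.uint8)
scales = raw.view(torch.float8_e8m0fnu)
print(scales.float())  # tensor([0.5000, 1.0000, 2.0000, 8.0000])
```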

torchao/utils.py

Lines changed: 1 addition & 0 deletions
@@ -614,6 +614,7 @@ def _torch_version_at_least(min_version):
 # | MI300X | gfx940, gfx941, gfx942 |
 # | MI350 | gfx950 |
 
+
 def is_ROCM():
     return torch.cuda.is_available() and torch.version.hip
 
