Skip to content

Commit d506cc7

Browse files
authored
Revert "Build mxfp4 kernel for sm120a" (#2428)
Revert "Build mxfp4 kernel for sm120a (#2285)" This reverts commit e73a142.
1 parent 4e25496 commit d506cc7

File tree

8 files changed

+50
-320
lines changed

8 files changed

+50
-320
lines changed

benchmarks/float8/bench_matmul.py

Lines changed: 17 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,6 @@
1616
get_name_to_shapes_iter,
1717
)
1818

19-
from torchao.ops import mx_fp4_bf16
20-
from torchao.prototype.mx_formats.mx_tensor import to_mx
2119
from torchao.testing.float8.roofline_utils import get_specs
2220

2321

@@ -64,19 +62,13 @@ def run(
6462
):
6563
device = "cuda"
6664
# TODO(future PR): this is ugly
67-
assert recipe in ("tensorwise", "rowwise", "mxfp8_cublas", "mxfp4_cutlass"), (
68-
"unsupported"
69-
)
70-
use_fp4 = recipe == "mxfp4_cutlass"
65+
assert recipe in ("tensorwise", "rowwise", "mxfp8_cublas"), "unsupported"
7166

7267
specs = get_specs()
7368
bf16_peak_tops = specs["bf16_peak_tops"]
7469
fp8_peak_tops = specs["fp8_peak_tops"]
75-
fp4_peak_tops = specs["fp4_peak_tops"]
7670
print(f"gpu_name: {torch.cuda.get_device_name(0)}")
77-
print(
78-
f"peak tops: bf16 {bf16_peak_tops:.2e}, fp8 {fp8_peak_tops:.2e}, fp4 {fp4_peak_tops:.2e}"
79-
)
71+
print(f"peak tops: bf16 {bf16_peak_tops:.2e}, fp8 {fp8_peak_tops:.2e}")
8072

8173
headers = (
8274
"fast_accum",
@@ -85,14 +77,14 @@ def run(
8577
"K",
8678
"N",
8779
"ref_time_s",
88-
"time_s",
89-
"speedup",
80+
"fp8_time_s",
81+
"fp8_speedup",
9082
)
9183
results = []
9284

9385
dtype = torch.bfloat16
9486
name_to_shapes = get_name_to_shapes_iter(shape_gen_name, M, K, N)
95-
fast_accum_vals = [False] if use_fp4 else [True, False]
87+
fast_accum_vals = [True, False]
9688

9789
for idx, (fast_accum, (name, (M, K, N))) in enumerate(
9890
itertools.product(fast_accum_vals, name_to_shapes)
@@ -115,53 +107,35 @@ def run(
115107

116108
del A
117109

118-
A_hp = torch.randn(M, K, device=device)
119-
B_hp_t = torch.randn(N, K, device=device)
120-
121-
if use_fp4:
122-
_, A = to_mx(A_hp, torch.float4_e2m1fn_x2, 32)
123-
_, Bt = to_mx(B_hp_t, torch.float4_e2m1fn_x2, 32)
124-
B = Bt.contiguous().T
125-
peak_tops = fp4_peak_tops
126-
else:
127-
# raw float8 matmul (upper bound for what we can achive in eager mode)
128-
# TODO(future): add e5m2
129-
d1, d2, d3 = torch.float8_e4m3fn, torch.float8_e4m3fn, dtype
130-
A = A_hp.to(d1)
131-
B = B_hp_t.to(d2).contiguous().T
132-
peak_tops = fp8_peak_tops
133-
110+
# raw float8 matmul (upper bound for what we can achive in eager mode)
111+
# TODO(future): add e5m2
112+
d1, d2, d3 = torch.float8_e4m3fn, torch.float8_e4m3fn, dtype
113+
A = torch.zeros(M, K, device=device, dtype=d1)
114+
B = torch.zeros(K, N, device=device, dtype=d2).t().contiguous().t()
134115
if recipe == "tensorwise":
135116
scale_a = torch.tensor([1.0], device=device)
136117
scale_b = torch.tensor([1.0], device=device)
137118
elif recipe == "rowwise":
138119
scale_a = torch.ones(M, 1, device=device)
139120
scale_b = torch.ones(1, N, device=device)
140-
elif recipe in ("mxfp8_cublas", "mxfp4_cutlass"):
121+
elif recipe == "mxfp8_cublas":
141122
scale_a = torch.ones(M, K // 32, device=device, dtype=torch.float8_e8m0fnu)
142123
scale_b = torch.ones(N, K // 32, device=device, dtype=torch.float8_e8m0fnu)
143124
else:
144125
assert False, f"unknown recipe {recipe}"
145126

146-
def do_matmul_fp8(A, B):
127+
def do_matmul(A, B):
147128
nonlocal scale_a
148129
nonlocal scale_b
149130
return torch._scaled_mm(
150131
A, B, scale_a, scale_b, out_dtype=d3, use_fast_accum=fast_accum
151132
)
152133

153-
def do_matmul_mxfp4(A, B):
154-
nonlocal scale_a
155-
nonlocal scale_b
156-
return mx_fp4_bf16(A, B, scale_a, scale_b)
157-
158-
do_matmul = do_matmul_mxfp4 if use_fp4 else do_matmul_fp8
159-
160-
time_sec, tops_sec, pct_top_peak = do_benchmarks(
161-
tops, peak_tops, use_gpu_kernel_time, do_matmul, A, B
134+
fp8_time_sec, fp8_tops_sec, fp8_pct_top_peak = do_benchmarks(
135+
tops, fp8_peak_tops, use_gpu_kernel_time, do_matmul, A, B
162136
)
163137
print(
164-
f"time_sec {time_sec:.2E}, tops/sec {tops_sec:.2E}, pct_peak {pct_top_peak:.3f}"
138+
f"fp8 time_sec {fp8_time_sec:.2E}, tops/sec {fp8_tops_sec:.2E}, pct_peak {fp8_pct_top_peak:.3f}"
165139
)
166140

167141
del A, B, scale_a, scale_b
@@ -174,8 +148,8 @@ def do_matmul_mxfp4(A, B):
174148
K,
175149
N,
176150
ref_time_sec,
177-
time_sec,
178-
ref_time_sec / time_sec,
151+
fp8_time_sec,
152+
ref_time_sec / fp8_time_sec,
179153
]
180154
)
181155

benchmarks/float8/utils.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -352,6 +352,9 @@ def get_gpu_kernel_gemm_time_s(f, *args, **kwargs):
352352
)
353353
# there is only 1 key, aten::mm or aten::_scaled_mm, with unit nanoseconds
354354
assert len(data) == 1
355-
key, value = next(iter(data.items()))
356-
assert key in ("aten::mm", "aten::_scaled_mm", "torchao::mx_fp4_bf16")
357-
return value / 1e6 / n_iter
355+
if "aten::mm" in data:
356+
return data["aten::mm"] / 1e6 / n_iter
357+
elif "aten::_scaled_mm" in data:
358+
return data["aten::_scaled_mm"] / 1e6 / n_iter
359+
else:
360+
raise AssertionError("unexpected format of data")

setup.py

Lines changed: 25 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -272,18 +272,15 @@ def get_cutlass_build_flags():
272272
raise ValueError("No CUDA version found")
273273

274274
major, minor = map(int, cuda_version.split(".")[:2])
275-
build_sm90a = (major, minor) >= (12, 6)
276-
build_sm100a = (major, minor) >= (12, 8)
277-
build_sm120a = (major, minor) >= (12, 8)
275+
build_sm90a = major > 12 or (major == 12 and minor >= 6)
276+
build_sm100a = major > 12 or (major == 12 and minor >= 8)
278277

279278
if build_sm90a:
280279
print(f"CUDA {cuda_version}: Enabling SM90a CUTLASS kernels")
281280
if build_sm100a:
282281
print(f"CUDA {cuda_version}: Enabling SM100a CUTLASS kernels")
283-
if build_sm120a:
284-
print(f"CUDA {cuda_version}: Enabling SM120a CUTLASS kernels")
285282

286-
return build_sm90a, build_sm100a, build_sm120a
283+
return build_sm90a, build_sm100a
287284
except:
288285
# Fallback to architecture flags
289286
cuda_arch_flags = _get_cuda_arch_flags()
@@ -343,11 +340,6 @@ def __init__(
343340
self.cmake_args = cmake_args
344341

345342

346-
def remove_items(a: list, b: list) -> list:
347-
"""Remove items in list b from list a"""
348-
return [x for x in a if x not in b]
349-
350-
351343
def get_extensions():
352344
# Skip building C++ extensions if USE_CPP is set to "0"
353345
if use_cpp == "0":
@@ -462,7 +454,7 @@ def get_extensions():
462454
excluded_sources = list(
463455
glob.glob(os.path.join(extensions_dir, "cpu/*.cpp"), recursive=True)
464456
)
465-
sources = remove_items(sources, excluded_sources)
457+
sources = [s for s in sources if s not in excluded_sources]
466458

467459
# Collect CUDA source files
468460
extensions_cuda_dir = os.path.join(extensions_dir, "cuda")
@@ -506,24 +498,22 @@ def get_extensions():
506498
rocm_sources = list(
507499
glob.glob(os.path.join(extensions_rocm_dir, "**/*.cpp"), recursive=True)
508500
)
509-
sources = remove_items(sources, rocm_sources)
501+
sources = [s for s in sources if s not in rocm_sources]
510502

511-
use_cutlass = use_cuda and not IS_WINDOWS
503+
use_cutlass = False
512504
cutlass_90a_sources = None
513505
cutlass_100a_sources = None
514-
cutlass_120a_sources = None
515506
build_for_sm90a = False
516507
build_for_sm100a = False
517-
build_for_sm120a = False
518-
519-
if use_cutlass:
508+
if use_cuda and not IS_WINDOWS:
509+
use_cutlass = True
520510
cutlass_dir = os.path.join(third_party_path, "cutlass")
521511
cutlass_include_dir = os.path.join(cutlass_dir, "include")
522512
cutlass_tools_include_dir = os.path.join(
523513
cutlass_dir, "tools", "util", "include"
524514
)
525515
cutlass_extensions_include_dir = os.path.join(cwd, extensions_cuda_dir)
526-
516+
if use_cutlass:
527517
extra_compile_args["nvcc"].extend(
528518
[
529519
"-DTORCHAO_USE_CUTLASS",
@@ -543,7 +533,7 @@ def get_extensions():
543533
]
544534
)
545535

546-
build_for_sm90a, build_for_sm100a, build_for_sm120a = get_cutlass_build_flags()
536+
build_for_sm90a, build_for_sm100a = get_cutlass_build_flags()
547537
# Define sm90a sources
548538
cutlass_90a_sources = [
549539
os.path.join(
@@ -567,40 +557,40 @@ def get_extensions():
567557
"rowwise_scaled_linear_sparse_cutlass_" + dtypes + ".cu",
568558
)
569559
)
570-
sources = remove_items(sources, cutlass_90a_sources)
560+
# Always remove sm90a sources from main sources
561+
sources = [s for s in sources if s not in cutlass_90a_sources]
571562

572563
# Always compile mx_fp_cutlass_kernels.cu ONLY with sm100a architecture
573564
cutlass_100a_sources = [
574565
os.path.join(
575566
extensions_cuda_dir,
576567
"mx_kernels",
577-
"mx_fp_cutlass_kernels_sm100a.cu",
568+
"mx_fp_cutlass_kernels.cu",
578569
),
579570
]
580-
sources = remove_items(sources, cutlass_100a_sources)
581-
582-
# Always compile mx_fp_cutlass_kernels.cu ONLY with sm120a architecture
583-
cutlass_120a_sources = [
584-
os.path.join(
585-
extensions_cuda_dir,
586-
"mx_kernels",
587-
"mx_fp_cutlass_kernels_sm120a.cu",
588-
),
571+
# Remove from main sources to prevent compilation with other architectures
572+
sources = [
573+
s for s in sources if os.path.basename(s) != "mx_fp_cutlass_kernels.cu"
589574
]
590-
sources = remove_items(sources, cutlass_120a_sources)
591575

592576
else:
593-
# Remove CUTLASS-based kernels from the sources list. An assumption is that
594-
# these files will have "cutlass" in its name.
577+
# Remove CUTLASS-based kernels from the sources list. An
578+
# assumption is that these files will have "cutlass" in its
579+
# name.
595580
cutlass_sources = list(
596581
glob.glob(
597582
os.path.join(extensions_cuda_dir, "**/*cutlass*.cu"), recursive=True
598583
)
599584
)
600-
sources = remove_items(sources, cutlass_sources)
585+
sources = [s for s in sources if s not in cutlass_sources]
601586

602587
ext_modules = []
603588
if len(sources) > 0:
589+
# Double-check to ensure mx_fp_cutlass_kernels.cu is not in sources
590+
sources = [
591+
s for s in sources if os.path.basename(s) != "mx_fp_cutlass_kernels.cu"
592+
]
593+
604594
ext_modules.append(
605595
extension(
606596
"torchao._C",
@@ -653,27 +643,6 @@ def get_extensions():
653643
)
654644
)
655645

656-
# Only build the cutlass_120a extension if sm120a is in the architecture flags
657-
if (
658-
cutlass_120a_sources is not None
659-
and len(cutlass_120a_sources) > 0
660-
and build_for_sm120a
661-
):
662-
cutlass_120a_extra_compile_args = copy.deepcopy(extra_compile_args)
663-
# Only use sm120a architecture for these sources, ignoring cuda_arch_flags
664-
cutlass_120a_extra_compile_args["nvcc"].append(
665-
"-gencode=arch=compute_120a,code=sm_120a"
666-
)
667-
ext_modules.append(
668-
extension(
669-
"torchao._C_cutlass_120a",
670-
cutlass_120a_sources,
671-
py_limited_api=True,
672-
extra_compile_args=cutlass_120a_extra_compile_args,
673-
extra_link_args=extra_link_args,
674-
)
675-
)
676-
677646
# Build CMakeLists from /torchao/experimental - additional options become available : TORCHAO_BUILD_CPU_AARCH64, TORCHAO_BUILD_KLEIDIAI, TORCHAO_BUILD_MPS_OPS, TORCHAO_PARALLEL_BACKEND
678647
if build_macos_arm_auto or os.getenv("BUILD_TORCHAO_EXPERIMENTAL") == "1":
679648
build_options = BuildOptions()

test/prototype/mx_formats/test_mx_mm.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from torchao.prototype.mx_formats.utils import to_blocked
1515
from torchao.utils import (
1616
TORCH_VERSION_AT_LEAST_2_8,
17-
is_sm_version,
17+
is_sm_at_least_100,
1818
)
1919

2020
if not TORCH_VERSION_AT_LEAST_2_8:
@@ -59,8 +59,7 @@ def run_matrix_test(M: int, K: int, N: int, format) -> float:
5959

6060
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
6161
@pytest.mark.skipif(
62-
not (is_sm_version(10, 0) or is_sm_version(12, 0)),
63-
reason="CUDA capability 10.0 or 12.0 is required for mxfloat8",
62+
not is_sm_at_least_100(), reason="CUDA capability >= 10.0 required for mxfloat8"
6463
)
6564
@pytest.mark.parametrize(
6665
"size",

torchao/__init__.py

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -25,21 +25,8 @@
2525

2626
so_files = list(Path(__file__).parent.glob("_C*.so"))
2727
if len(so_files) > 0:
28-
compute_capability = (
29-
torch.cuda.get_device_capability() if torch.cuda.is_available() else None
30-
)
31-
3228
for file in so_files:
33-
# only load architecture-specific target if the current GPU matches that target
34-
if (
35-
("cutlass_90a" in file.name and compute_capability != (9, 0))
36-
or ("cutlass_100a" in file.name and compute_capability != (10, 0))
37-
or ("cutlass_120a" in file.name and compute_capability != (12, 0))
38-
):
39-
continue
40-
4129
torch.ops.load_library(str(file))
42-
4330
from . import ops
4431

4532
# The following library contains CPU kernels from torchao/experimental

0 commit comments

Comments
 (0)