
Commit c376285

alexsamardzic authored and pytorchmergebot committed
Add CUTLASS-based row-wise scaled sparse FP8 kernel (#1671)
Pull Request resolved: #1671
Approved by: https://github.com/jcaip
1 parent 0eea64a commit c376285

32 files changed: +2177 −528 lines

benchmarks/benchmark_rowwise_scaled_linear_cutlass.py

Lines changed: 35 additions & 21 deletions
@@ -7,41 +7,55 @@
     rowwise_scaled_linear_cutlass_s4s4,
     rowwise_scaled_linear_cutlass_s8s4,
 )
+from torchao.quantization.quant_api import (
+    _int4_symm_cutlass_quant,
+    _int8_symm_cutlass_quant,
+)
+
+dtype = torch.bfloat16
+dtypeq = torch.int8
+dtype_scale = torch.float32
+device = torch.device("cuda")
 
 
 def benchmark_microseconds(f, *args):
     return do_bench(lambda: f(*args), return_mode="median") * 1e3
 
 
-def get_problem(m: int, n: int, k: int, A_nbits: int, B_nbits: int):
-    assert A_nbits in (4, 8) and B_nbits in (4, 8)
+def get_problem(m: int, n: int, k: int, Xq_nbits: int):
+    assert k % 2 == 0
+    assert Xq_nbits in [4, 8]
+
+    X_ref = torch.randn((m, k), dtype=dtype, device=device)
+    W_ref = torch.rand((n, k), dtype=dtype, device=device)
 
-    dev = torch.device("cuda")
-    A = torch.randint(-128, 127, (m, k * A_nbits // 8), dtype=torch.int8, device=dev)
-    A_scale = torch.randn((m,), dtype=torch.half, device=dev)
-    B = torch.randint(
-        -128, 127, size=(n, k * B_nbits // 8), dtype=torch.int8, device=dev
+    X_quant_func = (
+        _int4_symm_cutlass_quant if Xq_nbits == 4 else _int8_symm_cutlass_quant
     )
-    B_scale = torch.randn((n,), dtype=torch.half, device=dev)
-    C = None
+    W_quant_func = _int4_symm_cutlass_quant
+    X_aqt = X_quant_func(X_ref)
+    W_aqt = W_quant_func(W_ref)
 
-    return A, A_scale, B, B_scale, C
+    Xq = X_aqt.tensor_impl.int_data
+    X_scale = X_aqt.tensor_impl.scale
+    Wq = W_aqt.tensor_impl.int_data
+    W_scale = W_aqt.tensor_impl.scale
+    bias = None
+    out_dtype = dtype
 
+    return (X_ref, W_ref), (Xq, X_scale, Wq, W_scale, bias, out_dtype)
 
-def benchmark(m: int, k: int, n: int):
-    dev = torch.device("cuda")
-    A_ref = torch.randn((m, k), dtype=torch.half, device=dev)
-    B_ref = torch.randn((n, k), dtype=torch.half, device=dev)
-    fp16_time = benchmark_microseconds(torch.nn.functional.linear, A_ref, B_ref)
 
-    A, A_scale, B, B_scale, C = get_problem(m, n, k, 8, 4)
-    rowwise_scaled_linear_cutlass_s8s4_time = benchmark_microseconds(
-        rowwise_scaled_linear_cutlass_s8s4, A, A_scale, B, B_scale, C
+def benchmark(m: int, k: int, n: int):
+    ref_args, args = get_problem(m, n, k, 4)
+    fp16_time = benchmark_microseconds(torch.nn.functional.linear, *ref_args)
+    rowwise_scaled_linear_cutlass_s4s4_time = benchmark_microseconds(
+        rowwise_scaled_linear_cutlass_s4s4, *args
     )
 
-    A, A_scale, B, B_scale, C = get_problem(m, n, k, 4, 4)
-    rowwise_scaled_linear_cutlass_s4s4_time = benchmark_microseconds(
-        rowwise_scaled_linear_cutlass_s4s4, A, A_scale, B, B_scale, C
+    _, args = get_problem(m, n, k, 8)
+    rowwise_scaled_linear_cutlass_s8s4_time = benchmark_microseconds(
+        rowwise_scaled_linear_cutlass_s8s4, *args
     )
 
     return {
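The refactored get_problem now returns both the reference operands and the already-quantized arguments, in exactly the order the CUTLASS op expects. For reference, a minimal sketch of one such call outside the benchmark harness (shapes are arbitrary illustrations; all names are taken from the imports in this diff):

    import torch
    from torchao.ops import rowwise_scaled_linear_cutlass_s8s4
    from torchao.quantization.quant_api import (
        _int4_symm_cutlass_quant,
        _int8_symm_cutlass_quant,
    )

    m, n, k = 16, 32, 64  # hypothetical shapes; k must be even for int4 packing
    X_ref = torch.randn((m, k), dtype=torch.bfloat16, device="cuda")
    W_ref = torch.rand((n, k), dtype=torch.bfloat16, device="cuda")

    X_aqt = _int8_symm_cutlass_quant(X_ref)  # int8 activations
    W_aqt = _int4_symm_cutlass_quant(W_ref)  # int4 weights, packed two per int8

    out = rowwise_scaled_linear_cutlass_s8s4(
        X_aqt.tensor_impl.int_data,  # Xq
        X_aqt.tensor_impl.scale,     # X_scale
        W_aqt.tensor_impl.int_data,  # Wq
        W_aqt.tensor_impl.scale,     # W_scale
        None,                        # bias
        torch.bfloat16,              # out_dtype
    )
    # out should approximate torch.nn.functional.linear(X_ref, W_ref)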
benchmarks/benchmark_rowwise_scaled_linear_sparse_cutlass.py

Lines changed: 72 additions & 0 deletions
@@ -0,0 +1,72 @@
+import pandas as pd
+import torch
+from tqdm import tqdm
+from triton.testing import do_bench
+
+from torchao.ops import rowwise_scaled_linear_sparse_cutlass_f8f8
+from torchao.quantization.quant_api import (
+    _float8_cutlass_quant,
+    _float8_cutlass_quant_sparse,
+)
+from torchao.sparsity.utils import create_semi_structured_tensor
+
+dtype = torch.bfloat16
+dtypeq_X = torch.float8_e5m2
+dtypeq_W = torch.float8_e4m3fn
+device = torch.device("cuda")
+
+
+def benchmark_microseconds(f, *args):
+    return do_bench(lambda: f(*args), return_mode="median") * 1e3
+
+
+def get_problem(m: int, n: int, k: int):
+    X_ref = torch.randn((m, k), dtype=dtype, device=device)
+    W_ref = create_semi_structured_tensor(n, k, dtype=dtype).to(device)
+
+    X_quant_func = _float8_cutlass_quant
+    W_quant_func = _float8_cutlass_quant_sparse
+    X_aqt = X_quant_func(X_ref, dtypeq_X)
+    W_aqt = W_quant_func(W_ref, dtypeq_W)
+
+    Xq = X_aqt.tensor_impl.float8_data
+    X_scale = X_aqt.tensor_impl.scale
+    Wq_sparse = W_aqt.tensor_impl.sparse
+    W_meta = W_aqt.tensor_impl.meta
+    W_scale = W_aqt.tensor_impl.scale
+    bias = None
+    out_dtype = dtype
+
+    return (X_ref, W_ref), (Xq, X_scale, Wq_sparse, W_meta, W_scale, bias, out_dtype)
+
+
+def benchmark(m: int, k: int, n: int):
+    ref_args, args = get_problem(m, n, k)
+    fp16_time = benchmark_microseconds(torch.nn.functional.linear, *ref_args)
+    rowwise_scaled_linear_sparse_cutlass_f8f8_time = benchmark_microseconds(
+        rowwise_scaled_linear_sparse_cutlass_f8f8, *args
+    )
+
+    return {
+        "m": m,
+        "k": k,
+        "n": n,
+        "fp16_latency (ms)": fp16_time,
+        "rowwise_scaled_linear_sparse_cutlass_f8f8 latency (ms)": rowwise_scaled_linear_sparse_cutlass_f8f8_time,
+        "f8f8 speedup (d/s)": fp16_time
+        / rowwise_scaled_linear_sparse_cutlass_f8f8_time,
+    }
+
+
+if __name__ == "__main__":
+    k_vals = (8192, 8192, 8192, 28672)
+    n_vals = (8192, 10240, 57344, 8192)
+
+    results = []
+    for m in tqdm([1 << i for i in range(10)]):
+        for n, k in zip(n_vals, k_vals):
+            results.append(benchmark(m, k, n))
+
+    df = pd.DataFrame(results)
+    df.to_csv("rowwise_scaled_linear_sparse_cutlass_time_results.csv", index=False)
+    print(df.to_markdown(index=False))
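The weight path depends on 2:4 ("semi-structured") sparsity: in every contiguous group of four values along the reduction dimension, only two are non-zero, so CUTLASS can store half the values plus a compact metadata tensor (the sparse/meta pair unpacked above). A minimal sanity check of that invariant, assuming create_semi_structured_tensor keeps the signature used in this benchmark:

    import torch
    from torchao.sparsity.utils import create_semi_structured_tensor

    W = create_semi_structured_tensor(128, 256, dtype=torch.bfloat16)
    groups = W.reshape(128, -1, 4)  # groups of four along the k dimension
    assert ((groups != 0).sum(dim=-1) <= 2).all()  # 2:4: at most two non-zeros per group

The FP8 dtype split is also deliberate: torch.finfo reports a max of 57344 for float8_e5m2 (more exponent range, used here for activations) versus 448 for float8_e4m3fn (more mantissa precision, used for weights).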

docs/source/api_ref_dtypes.rst

Lines changed: 1 addition & 0 deletions
@@ -28,6 +28,7 @@ Layouts and Tensor Subclasses
     MarlinQQQLayout
     Int4CPULayout
     CutlassInt4PackedLayout
+    CutlassSemiSparseLayout
 
 Quantization techniques
 -----------------------

setup.py

Lines changed: 89 additions & 30 deletions
@@ -3,6 +3,7 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
+import copy
 import glob
 import os
 import subprocess
@@ -75,6 +76,7 @@ def use_debug_mode():
     BuildExtension,
     CppExtension,
     CUDAExtension,
+    _get_cuda_arch_flags,
 )
 
 IS_ROCM = (torch.version.hip is not None) and (ROCM_HOME is not None)
@@ -269,7 +271,12 @@ def get_extensions():
     extra_link_args = []
     extra_compile_args = {
         "cxx": [f"-DPy_LIMITED_API={PY3_9_HEXCODE}"],
-        "nvcc": ["-O3" if not debug_mode else "-O0", "-t=0", "-std=c++17"],
+        "nvcc": [
+            "-DNDEBUG" if not debug_mode else "-DDEBUG",
+            "-O3" if not debug_mode else "-O0",
+            "-t=0",
+            "-std=c++17",
+        ],
     }
 
     if not IS_WINDOWS:
@@ -304,25 +311,6 @@ def get_extensions():
     if use_cuda:
         sources += cuda_sources
 
-    use_cutlass = False
-    if use_cuda and not IS_WINDOWS:
-        use_cutlass = True
-        cutlass_dir = os.path.join(third_party_path, "cutlass")
-        cutlass_include_dir = os.path.join(cutlass_dir, "include")
-        cutlass_tools_include_dir = os.path.join(
-            cutlass_dir, "tools", "util", "include"
-        )
-        cutlass_extensions_include_dir = os.path.join(cwd, extensions_cuda_dir)
-    if use_cutlass:
-        extra_compile_args["nvcc"].extend(
-            [
-                "-DTORCHAO_USE_CUTLASS",
-                "-I" + cutlass_include_dir,
-                "-I" + cutlass_tools_include_dir,
-                "-I" + cutlass_extensions_include_dir,
-            ]
-        )
-
     # Get base directory and source paths
     curdir = os.path.dirname(os.path.curdir)
     extensions_dir = os.path.join(curdir, "torchao", "csrc")
@@ -349,16 +337,6 @@ def get_extensions():
     # Collect CUDA source files if needed
     if not IS_ROCM and use_cuda:
         sources += cuda_sources
-    else:
-        # Remove CUTLASS-based kernels from the cuda_sources list. An
-        # assumption is that these files will have "cutlass" in its
-        # name.
-        cutlass_sources = list(
-            glob.glob(
-                os.path.join(extensions_cuda_dir, "**/*cutlass*.cu"), recursive=True
-            )
-        )
-        sources = [s for s in sources if s not in cutlass_sources]
 
     # TOOD: Remove this and use what CUDA has once we fix all the builds.
     if IS_ROCM and use_cuda:
@@ -372,6 +350,72 @@ def get_extensions():
     else:
         sources += hip_sources
 
+    use_cutlass = False
+    cutlass_90a_sources = None
+    if use_cuda and not IS_ROCM and not IS_WINDOWS:
+        use_cutlass = True
+        cutlass_dir = os.path.join(third_party_path, "cutlass")
+        cutlass_include_dir = os.path.join(cutlass_dir, "include")
+        cutlass_tools_include_dir = os.path.join(
+            cutlass_dir, "tools", "util", "include"
+        )
+        cutlass_extensions_include_dir = os.path.join(cwd, extensions_cuda_dir)
+    if use_cutlass:
+        extra_compile_args["nvcc"].extend(
+            [
+                "-DTORCHAO_USE_CUTLASS",
+                "-I" + cutlass_include_dir,
+                "-I" + cutlass_tools_include_dir,
+                "-I" + cutlass_extensions_include_dir,
+                "-DCUTE_USE_PACKED_TUPLE=1",
+                "-DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED",
+                "-DCUTLASS_ENABLE_TENSOR_CORE_MMA=1",
+                "-DCUTLASS_DEBUG_TRACE_LEVEL=0",
+                "--ftemplate-backtrace-limit=0",
+                # "--keep",
+                # "--ptxas-options=--verbose,--register-usage-level=5,--warn-on-local-memory-usage",
+                # "--resource-usage",
+                # "-lineinfo",
+                # "-DCUTLASS_ENABLE_GDC_FOR_SM90",  # https://github.com/NVIDIA/cutlass/blob/main/media/docs/dependent_kernel_launch.md
+            ]
+        )
+
+        cuda_arch_flags = _get_cuda_arch_flags()
+        build_for_sm90 = "-gencode=arch=compute_90,code=sm_90" in cuda_arch_flags
+        build_for_sm90a = "-gencode=arch=compute_90a,code=sm_90a" in cuda_arch_flags
+        if build_for_sm90 and not build_for_sm90a:
+            cutlass_90a_sources = [
+                os.path.join(
+                    extensions_cuda_dir,
+                    "rowwise_scaled_linear_sparse_cutlass",
+                    "rowwise_scaled_linear_sparse_cutlass_f8f8.cu",
+                ),
+                os.path.join(
+                    extensions_cuda_dir,
+                    "to_sparse_semi_structured_cutlass_sm9x",
+                    "to_sparse_semi_structured_cutlass_sm9x_f8.cu",
+                ),
+            ]
+            for dtypes in ["e4m3e4m3", "e4m3e5m2", "e5m2e4m3", "e5m2e5m2"]:
+                cutlass_90a_sources.append(
+                    os.path.join(
+                        extensions_cuda_dir,
+                        "rowwise_scaled_linear_sparse_cutlass",
+                        "rowwise_scaled_linear_sparse_cutlass_" + dtypes + ".cu",
+                    )
+                )
+            sources = [s for s in sources if s not in cutlass_90a_sources]
+    else:
+        # Remove CUTLASS-based kernels from the sources list. An
+        # assumption is that these files will have "cutlass" in its
+        # name.
+        cutlass_sources = list(
+            glob.glob(
+                os.path.join(extensions_cuda_dir, "**/*cutlass*.cu"), recursive=True
+            )
+        )
+        sources = [s for s in sources if s not in cutlass_sources]
+
     ext_modules = []
     if len(sources) > 0:
         ext_modules.append(
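The sm_90a gating keys off the exact flag strings that torch.utils.cpp_extension generates: a plain sm_90 entry in the arch list means the FP8 sparse kernels, which need the 90a feature set, are pulled out of the default source list and built separately. A quick way to inspect what the build will match against (a sketch, assuming the private _get_cuda_arch_flags helper keeps its current behavior of reading TORCH_CUDA_ARCH_LIST when called with no arguments):

    import os

    os.environ["TORCH_CUDA_ARCH_LIST"] = "9.0"  # hypothetical arch list
    from torch.utils.cpp_extension import _get_cuda_arch_flags

    print(_get_cuda_arch_flags())
    # expected: ['-gencode=arch=compute_90,code=sm_90'] -- so build_for_sm90 is
    # True, build_for_sm90a is False, and the CUTLASS FP8 sparse sources are
    # split out for a separate sm_90a compilation pass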
@@ -384,6 +428,21 @@ def get_extensions():
             )
         )
 
+    if cutlass_90a_sources is not None and len(cutlass_90a_sources) > 0:
+        cutlass_90a_extra_compile_args = copy.deepcopy(extra_compile_args)
+        cutlass_90a_extra_compile_args["nvcc"].extend(
+            cuda_arch_flags + ["-gencode=arch=compute_90a,code=sm_90a"]
+        )
+        ext_modules.append(
+            extension(
+                "torchao._C",
+                cutlass_90a_sources,
+                py_limited_api=True,
+                extra_compile_args=cutlass_90a_extra_compile_args,
+                extra_link_args=extra_link_args,
+            )
+        )
+
     if build_torchao_experimental:
         build_options = BuildOptions()
 
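The copy.deepcopy above is load-bearing: extra_compile_args nests a list, so a shallow copy would share the "nvcc" list between both extensions and the 90a-only -gencode flags would leak into the base build. A minimal illustration in plain Python:

    import copy

    extra_compile_args = {"nvcc": ["-O3", "-std=c++17"]}

    shallow = dict(extra_compile_args)
    shallow["nvcc"].append("-gencode=arch=compute_90a,code=sm_90a")
    print(extra_compile_args["nvcc"])  # base flags mutated: the nested list is shared

    extra_compile_args = {"nvcc": ["-O3", "-std=c++17"]}
    deep = copy.deepcopy(extra_compile_args)
    deep["nvcc"].append("-gencode=arch=compute_90a,code=sm_90a")
    print(extra_compile_args["nvcc"])  # unchanged: deepcopy duplicated the list

Note that both source sets are registered under the same extension name, "torchao._C", so the sm_90a kernels link into the single torchao extension module alongside the default build.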