Commit 9b1256f

2:4 activation sparsity packing kernels (#2012)
This PR gives users the ability to accelerate LLMs with 2:4 activation sparsity, using the approach outlined in our ICLR workshop paper: https://arxiv.org/abs/2503.16672. The main contribution is a CUTLASS 24_fp8_pack kernel, copied over from xFormers, that can compute the packed data and metadata from a normal dense tensor relatively efficiently.

### Performance Benchmarks

```
python benchmarks/benchmark_e2e_fp8_sparse_linear.py
```

| num_tokens | bf16_latency (us) | bf16_c_latency (us) | fp8_c_time (us) | fp8_c_sparse_time (us) | fp8_c_activation_sparse_time (us) | speedup |
|-----------:|------------------:|--------------------:|----------------:|-----------------------:|----------------------------------:|--------:|
| 64 | 166.816 | 163.04 | 103.008 | 74.304 | 102.816 | 1.00187 |
| 128 | 156.256 | 151.52 | 99.936 | 75.456 | 102.048 | 0.979304 |
| 256 | 172.288 | 159.584 | 114.08 | 82.432 | 111.072 | 1.02708 |
| 512 | 218.88 | 204.608 | 144.096 | 114.56 | 139.488 | 1.03304 |
| 1024 | 394.4 | 392.544 | 251.104 | 196.416 | 227.904 | 1.1018 |
| 2048 | 764.608 | 734.816 | 480.704 | 381.152 | 426.688 | 1.12659 |
| 4096 | 1658.82 | 1623.58 | 901.344 | 779.008 | 843.392 | 1.06871 |

### Tests

```
pytest tests/sparsity/test_activation24.py
```
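For intuition, 2:4 semi-structured sparsity keeps at most two values in every contiguous group of four elements along the reduction dimension; the kernel added here computes the packed values plus the CUTLASS metadata on-device. Below is a minimal dense PyTorch sketch of the magnitude-based pruning rule only; the `prune_24_reference` helper is illustrative and not part of this PR, and it omits the packed layout and metadata entirely:

```
import torch


def prune_24_reference(x: torch.Tensor) -> torch.Tensor:
    """Zero the two smallest-magnitude values in each group of four.

    Dense illustration of the 2:4 constraint only; the actual kernel also
    emits the packed values and sparse metadata, which this sketch omits.
    """
    groups = x.reshape(-1, 4)
    # keep the indices of the top-2 magnitudes within each group of four
    keep = groups.abs().topk(2, dim=-1).indices
    mask = torch.zeros_like(groups, dtype=torch.bool).scatter_(-1, keep, True)
    return (groups * mask).reshape(x.shape)


x = torch.randn(8, 16)
# every group of four now has at most two nonzeros
assert (prune_24_reference(x).reshape(-1, 4) != 0).sum(dim=-1).le(2).all()
```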
1 parent 66eb801 commit 9b1256f

18 files changed (+1906, −10 lines)
benchmarks/benchmark_e2e_fp8_sparse_linear.py

Lines changed: 133 additions & 0 deletions
@@ -0,0 +1,133 @@
```
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
import pandas as pd
import torch
from torch import nn
from tqdm import tqdm
from triton.testing import do_bench

from torchao.prototype.sparsity.activation.srelu_linear import (
    SRELUFloat8SemiSparseDynamicActivationFloat8WeightConfig,
)
from torchao.prototype.sparsity.activation.utils import SquaredReLU
from torchao.quantization import (
    Float8DynamicActivationFloat8SemiSparseWeightConfig,
    Float8DynamicActivationFloat8WeightConfig,
    Float8MMConfig,
    PerRow,
    quantize_,
)


def benchmark_microseconds(f, *args):
    return do_bench(lambda: f(*args), return_mode="median") * 1e3


def benchmark(num_tokens, hidden_size=8192, intermediate_size=8192):
    ffn_ref = (
        nn.Sequential(
            nn.Linear(hidden_size, intermediate_size, bias=False),
            SquaredReLU(),
            nn.Linear(intermediate_size, hidden_size, bias=False),
        )
        .to(torch.bfloat16)
        .cuda()
    )

    input_tensor = torch.randn(num_tokens, hidden_size).to(torch.bfloat16).cuda()
    fp16_time = benchmark_microseconds(ffn_ref, input_tensor)

    # bf16
    ffn_clone = (
        nn.Sequential(
            nn.Linear(hidden_size, intermediate_size, bias=False),
            SquaredReLU(),
            nn.Linear(intermediate_size, hidden_size, bias=False),
        )
        .to(torch.bfloat16)
        .cuda()
    )
    ffn_clone.forward = torch.compile(ffn_clone.forward, fullgraph=True)
    fp16_c_time = benchmark_microseconds(ffn_clone, input_tensor)

    # fp8
    ffn_clone = (
        nn.Sequential(
            nn.Linear(hidden_size, intermediate_size, bias=False),
            SquaredReLU(),
            nn.Linear(intermediate_size, hidden_size, bias=False),
        )
        .to(torch.bfloat16)
        .cuda()
    )
    quantize_(
        ffn_clone,
        Float8DynamicActivationFloat8WeightConfig(
            granularity=PerRow(), mm_config=Float8MMConfig(use_fast_accum=True)
        ),
    )
    ffn_clone.forward = torch.compile(ffn_clone.forward, fullgraph=True)
    fp8_c_time = benchmark_microseconds(ffn_clone, input_tensor)

    # fp8 sparse
    ffn_clone = (
        nn.Sequential(
            nn.Linear(hidden_size, intermediate_size, bias=False),
            SquaredReLU(),
            nn.Linear(intermediate_size, hidden_size, bias=False),
        )
        .to(torch.bfloat16)
        .cuda()
    )
    quantize_(ffn_clone, Float8DynamicActivationFloat8SemiSparseWeightConfig())
    ffn_clone.forward = torch.compile(ffn_clone.forward, fullgraph=True)
    fp8_c_sparse_time = benchmark_microseconds(ffn_clone, input_tensor)

    # activation fp8 sparse
    ffn_clone = (
        nn.Sequential(
            nn.Linear(hidden_size, intermediate_size, bias=False),
            # no SquaredReLU since it will be fused into the second linear
            nn.Linear(intermediate_size, hidden_size, bias=False),
        )
        .to(torch.bfloat16)
        .cuda()
    )
    quantize_(
        ffn_clone[0],
        Float8DynamicActivationFloat8WeightConfig(
            granularity=PerRow(), mm_config=Float8MMConfig(use_fast_accum=True)
        ),
    )
    quantize_(
        ffn_clone,
        SRELUFloat8SemiSparseDynamicActivationFloat8WeightConfig(),
        filter_fn=lambda mod, fqn: "1" in fqn,
    )
    ffn_clone.forward = torch.compile(ffn_clone.forward, fullgraph=True)
    fp8_c_activation_sparse_time = benchmark_microseconds(ffn_clone, input_tensor)

    return {
        "num_tokens": num_tokens,
        "bf16_latency (us)": fp16_time,
        "bf16_c_latency (us)": fp16_c_time,
        "fp8_c_time (us)": fp8_c_time,
        "fp8_c_sparse_time (us)": fp8_c_sparse_time,
        "fp8_c_activation_sparse_time (us)": fp8_c_activation_sparse_time,
        "speedup": fp8_c_time / fp8_c_activation_sparse_time,
    }


if __name__ == "__main__":
    with torch.no_grad():
        results = []
        for num_tokens in tqdm([64, 128, 256, 512, 1024, 2048, 4096]):
            results.append(benchmark(num_tokens))
            torch.compiler.reset()

        df = pd.DataFrame(results)
        df.to_csv("e2e_fp8_sparse.csv", index=False)
        print(df.to_markdown(index=False))
```
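The activation-sparse branch above is the pattern to reuse on a real model: quantize the up-projection with ordinary dynamic fp8, drop the explicit `SquaredReLU`, and apply the fused config to the down-projection. A hedged sketch of that flow on a standalone FFN; the `"up"`/`"down"` module names are placeholders for a real model's layer names, not torchao API:

```
import torch
from torch import nn

from torchao.prototype.sparsity.activation.srelu_linear import (
    SRELUFloat8SemiSparseDynamicActivationFloat8WeightConfig,
)
from torchao.quantization import (
    Float8DynamicActivationFloat8WeightConfig,
    Float8MMConfig,
    PerRow,
    quantize_,
)

# Placeholder FFN; the "up"/"down" names just give filter_fn something to
# match on. No explicit SquaredReLU: it is fused into the second linear.
ffn = nn.Sequential()
ffn.add_module("up", nn.Linear(8192, 8192, bias=False))
ffn.add_module("down", nn.Linear(8192, 8192, bias=False))
ffn = ffn.to(torch.bfloat16).cuda()

# Dense dynamic fp8 for the up-projection.
quantize_(
    ffn,
    Float8DynamicActivationFloat8WeightConfig(
        granularity=PerRow(), mm_config=Float8MMConfig(use_fast_accum=True)
    ),
    filter_fn=lambda mod, fqn: "up" in fqn,
)
# Fused squared-ReLU + 2:4 activation sparsification + fp8 down-projection.
quantize_(
    ffn,
    SRELUFloat8SemiSparseDynamicActivationFloat8WeightConfig(),
    filter_fn=lambda mod, fqn: "down" in fqn,
)
```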

benchmarks/benchmark_rowwise_scaled_linear_sparse_cutlass.py

Lines changed: 46 additions & 8 deletions
```
@@ -16,7 +16,7 @@
 from torchao.sparsity.utils import create_semi_structured_tensor
 
 dtype = torch.bfloat16
-dtypeq_X = torch.float8_e5m2
+dtypeq_X = torch.float8_e4m3fn
 dtypeq_W = torch.float8_e4m3fn
 device = torch.device("cuda")
 
@@ -25,7 +25,7 @@ def benchmark_microseconds(f, *args):
     return do_bench(lambda: f(*args), return_mode="median") * 1e3
 
 
-def get_problem(m: int, n: int, k: int):
+def get_problem_cutlass(m: int, n: int, k: int):
     X_ref = torch.randn((m, k), dtype=dtype, device=device)
     W_ref = create_semi_structured_tensor(n, k, dtype=dtype).to(device)
 
@@ -45,30 +45,68 @@ def get_problem(m: int, n: int, k: int):
     return (X_ref, W_ref), (Xq, X_scale, Wq_sparse, W_meta, W_scale, bias, out_dtype)
 
 
+def get_problem_cusparselt(m: int, n: int, k: int):
+    X_ref = torch.randn((m, k), dtype=dtype, device=device)
+    W_ref = create_semi_structured_tensor(n, k, dtype=dtype).to(device)
+
+    Xq = X_ref.to(dtypeq_W)
+    Wq = W_ref.to(dtypeq_W)
+
+    Wqs = torch._cslt_compress(Wq)
+
+    alg_id, split_k, split_k_one_kernel, _ = torch._C._cusparselt.mm_search(
+        Wqs, Xq.t(), None, None, None, False
+    )
+
+    return (Wqs, Xq.t(), None, None, dtype, False, alg_id, split_k, split_k_one_kernel)
+
+
+def get_problem_scaled_mm(m: int, n: int, k: int):
+    X_ref = torch.randn((m, k), dtype=dtype, device=device)
+    W_ref = create_semi_structured_tensor(n, k, dtype=dtype).to(device)
+
+    X_aqt = _float8_cutlass_quant(X_ref, dtypeq_W)
+    W_aqt = _float8_cutlass_quant(W_ref, dtypeq_W)
+
+    Xq = X_aqt.tensor_impl.float8_data
+    Wq = W_aqt.tensor_impl.float8_data
+    X_scale = X_aqt.tensor_impl.scale.unsqueeze(0)
+    W_scale = W_aqt.tensor_impl.scale.unsqueeze(-1)
+
+    return (Wq, Xq.t(), W_scale, X_scale, None, None, dtype)
+
+
 def benchmark(m: int, k: int, n: int):
-    ref_args, args = get_problem(m, n, k)
+    ref_args, args = get_problem_cutlass(m, n, k)
     fp16_time = benchmark_microseconds(torch.nn.functional.linear, *ref_args)
     rowwise_scaled_linear_sparse_cutlass_f8f8_time = benchmark_microseconds(
         rowwise_scaled_linear_sparse_cutlass_f8f8, *args
     )
 
+    cslt_args = get_problem_cusparselt(m, n, k)
+    cusparselt_time = benchmark_microseconds(torch._cslt_sparse_mm, *cslt_args)
+
+    fp8_args = get_problem_scaled_mm(m, n, k)
+    fp8_time = benchmark_microseconds(torch._scaled_mm, *fp8_args)
+
     return {
         "m": m,
         "k": k,
         "n": n,
         "fp16_latency (ms)": fp16_time,
+        "fp8_latency (ms)": fp8_time,
         "rowwise_scaled_linear_sparse_cutlass_f8f8 latency (ms)": rowwise_scaled_linear_sparse_cutlass_f8f8_time,
-        "f8f8 speedup (d/s)": fp16_time
-        / rowwise_scaled_linear_sparse_cutlass_f8f8_time,
+        "cusparselt latency (ms)": cusparselt_time,
+        "f8f8 speedup (d/s)": fp8_time / rowwise_scaled_linear_sparse_cutlass_f8f8_time,
     }
 
 
 if __name__ == "__main__":
-    k_vals = (8192, 8192, 8192, 28672)
-    n_vals = (8192, 10240, 57344, 8192)
+    k_vals = (8192,)
+    n_vals = (8192,)
 
     results = []
-    for m in tqdm([1 << i for i in range(10)]):
+    for m in tqdm([2048, 4096, 8192]):
         for n, k in zip(n_vals, k_vals):
             results.append(benchmark(m, k, n))
 
```
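The new dense-fp8 baseline added above reduces to a single rowwise-scaled `torch._scaled_mm` call with the same positional arguments the diff passes. A minimal standalone sketch with unit scales and arbitrary shapes; `_scaled_mm` is a private PyTorch API, and rowwise fp8 scaling assumes a recent PyTorch build on SM90 hardware:

```
import torch

m, n, k = 2048, 8192, 8192

# fp8 operands: W is (n, k) row-major; X.t() is column-major, as _scaled_mm expects
W = torch.randn(n, k, device="cuda").to(torch.float8_e4m3fn)
X = torch.randn(m, k, device="cuda").to(torch.float8_e4m3fn)
W_scale = torch.ones(n, 1, device="cuda")  # per-row scale of W
X_scale = torch.ones(1, m, device="cuda")  # per-row scale of X, as a row vector

# positional args mirror the diff: (Wq, Xq.t(), W_scale, X_scale, bias, scale_result, out_dtype)
out = torch._scaled_mm(W, X.t(), W_scale, X_scale, None, None, torch.bfloat16)
assert out.shape == (n, m)
```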

Lines changed: 124 additions & 0 deletions
@@ -0,0 +1,124 @@
```
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
import pandas as pd
import torch
from triton.testing import do_bench

from torchao.ops import (
    to_sparse_semi_structured_cutlass_sm9x_f8,
)
from torchao.quantization.quant_api import (
    _float8_cutlass_quant,
    _float8_cutlass_quant_sparse,
)
from torchao.sparsity.utils import create_semi_structured_tensor

dtype = torch.bfloat16
dtypeq_X = torch.float8_e4m3fn
dtypeq_W = torch.float8_e4m3fn
device = torch.device("cuda")


def benchmark_microseconds(f, *args):
    return do_bench(lambda: f(*args), return_mode="median") * 1e3


def get_problem_cutlass(m: int, n: int, k: int):
    X_ref = torch.randn((m, k), dtype=dtype, device=device)
    W_ref = create_semi_structured_tensor(n, k, dtype=dtype).to(device)

    X_quant_func = _float8_cutlass_quant
    W_quant_func = _float8_cutlass_quant_sparse
    X_aqt = X_quant_func(X_ref, dtypeq_X)
    W_aqt = W_quant_func(W_ref, dtypeq_W)

    Xq = X_aqt.tensor_impl.float8_data
    X_scale = X_aqt.tensor_impl.scale
    Wq_sparse = W_aqt.tensor_impl.sparse
    W_meta = W_aqt.tensor_impl.meta
    W_scale = W_aqt.tensor_impl.scale
    bias = None
    out_dtype = dtype

    return (X_ref, W_ref), (Xq, X_scale, Wq_sparse, W_meta, W_scale, bias, out_dtype)


def get_problem_cusparselt(m: int, n: int, k: int):
    X_ref = torch.randn((m, k), dtype=dtype, device=device)
    W_ref = create_semi_structured_tensor(n, k, dtype=dtype).to(device)

    Xq = X_ref.to(dtypeq_W)
    Wq = W_ref.to(dtypeq_W)

    Wqs = torch._cslt_compress(Wq)

    alg_id, split_k, split_k_one_kernel, _ = torch._C._cusparselt.mm_search(
        Wqs, Xq.t(), None, None, None, False
    )

    return (Wqs, Xq.t(), None, None, dtype, False, alg_id, split_k, split_k_one_kernel)


def get_problem_scaled_mm(m: int, n: int, k: int):
    X_ref = torch.randn((m, k), dtype=dtype, device=device)
    W_ref = create_semi_structured_tensor(n, k, dtype=dtype).to(device)

    X_aqt = _float8_cutlass_quant(X_ref, dtypeq_W)
    W_aqt = _float8_cutlass_quant(W_ref, dtypeq_W)

    Xq = X_aqt.tensor_impl.float8_data
    Wq = W_aqt.tensor_impl.float8_data
    X_scale = X_aqt.tensor_impl.scale.unsqueeze(0)
    W_scale = W_aqt.tensor_impl.scale.unsqueeze(-1)

    return (Wq, Xq.t(), W_scale, X_scale, None, None, dtype)


def benchmark(m, k):
    torch.manual_seed(123)
    W_ref = create_semi_structured_tensor(m, k, dtype=torch.float8_e4m3fn).cuda()

    # packed, meta = torch.ops.torchao.sparse_semi_structured_tile.default(W_ref, "", True)
    cutlass_reference_args = (W_ref,)
    cutlass_custom_args = (W_ref, "", True)

    cutlass_reference_compression_time = benchmark_microseconds(
        to_sparse_semi_structured_cutlass_sm9x_f8, *cutlass_reference_args
    )
    cutlass_custom_compression_time = benchmark_microseconds(
        torch.ops.torchao.sparse_semi_structured_tile.default, *cutlass_custom_args
    )

    return {
        "cutlass_reference (ms)": cutlass_reference_compression_time,
        "cutlass_custom (ms)": cutlass_custom_compression_time,
    }


def profile():
    torch.manual_seed(123)
    W_ref = create_semi_structured_tensor(8192, 8192, dtype=torch.float8_e4m3fn).cuda()

    # clear cache
    new_val = torch.empty(10000, 10000, device="cuda")
    new_val[:, :] = 0

    packed, meta = torch.ops.torchao.sparse_semi_structured_tile.default(
        W_ref, "", True
    )


if __name__ == "__main__":
    results = []
    for m in (2048, 4096, 8192):
        results.append(benchmark(m, 8192))

    df = pd.DataFrame(results)
    df.to_csv("rowwise_scaled_linear_sparse_cutlass_time_results.csv", index=False)
    print(df.to_markdown(index=False))

    # print("PROFILING")
    # profile()
```
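Both ops being timed above take a dense fp8 tensor that already satisfies the 2:4 pattern and return a (packed values, metadata) pair. A quick hedged sanity check, assuming the packed tensor keeps two of every four values per row (k/2 columns), as in PyTorch's semi-structured sparse layout:

```
import torch
from torchao.sparsity.utils import create_semi_structured_tensor

# dense fp8 tensor that already satisfies the 2:4 pattern
W = create_semi_structured_tensor(128, 128, dtype=torch.float8_e4m3fn).cuda()
packed, meta = torch.ops.torchao.sparse_semi_structured_tile.default(W, "", True)
# packed should hold k/2 values per row alongside the index metadata
print(packed.shape, meta.shape)
```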

e2e_fp8_sparse.csv

Lines changed: 8 additions & 0 deletions
@@ -0,0 +1,8 @@
```
num_tokens,bf16_latency (us),bf16_c_latency (us),fp8_c_time (us),fp8_c_sparse_time (us),fp8_c_activation_sparse_time (us),speedup
64,166.81599617004395,163.03999722003937,103.00800204277039,74.30399954319,102.81600058078766,1.0018674278409796
128,156.25600516796112,151.5199989080429,99.93600100278854,75.45600086450577,102.04800218343735,0.9793038458817415
256,172.28800058364868,159.58400070667267,114.07999694347382,82.43200182914734,111.07199639081955,1.0270815385551393
512,218.87999773025513,204.6079933643341,144.0960019826889,114.56000059843063,139.48799669742584,1.0330351384661336
1024,394.4000005722046,392.5440013408661,251.10399723052979,196.4160054922104,227.90400683879852,1.1017972027501084
2048,764.6080255508423,734.8160147666931,480.70400953292847,381.1520040035248,426.68798565864563,1.1265937305239622
4096,1658.8159799575806,1623.5840320587158,901.3440012931824,779.0079712867737,843.392014503479,1.0687129896811043
```
Lines changed: 4 additions & 0 deletions
@@ -0,0 +1,4 @@
```
m,k,n,fp16_latency (ms),fp8_latency (ms),rowwise_scaled_linear_sparse_cutlass_f8f8 latency (ms),cusparselt latency (ms),f8f8 speedup (d/s)
2048,8192,8192,345.7919955253601,243.13600361347198,159.7760021686554,634.2080235481262,1.5217304245528933
4096,8192,8192,756.3199996948242,500.2880096435547,363.647997379303,628.7999749183655,1.3757480124982768
8192,8192,8192,1433.568000793457,982.5279712677002,895.3920006752014,859.935998916626,1.0973160029649482
```

setup.py

Lines changed: 1 addition & 0 deletions
```
@@ -382,6 +382,7 @@ def get_extensions():
             "to_sparse_semi_structured_cutlass_sm9x",
             "to_sparse_semi_structured_cutlass_sm9x_f8.cu",
         ),
+        os.path.join(extensions_cuda_dir, "activation24", "sparsify24.cu"),
     ]
     for dtypes in ["e4m3e4m3", "e4m3e5m2", "e5m2e4m3", "e5m2e5m2"]:
         cutlass_90a_sources.append(
```
