Commit 4671ac6

[Bugfix][Benchmark] Fix Marlin benchmark (#19929)
1 parent dd2ccf8 commit 4671ac6


benchmarks/kernels/benchmark_marlin.py

Lines changed: 150 additions & 79 deletions
@@ -22,8 +22,16 @@
     MARLIN_SUPPORTED_GROUP_SIZES,
     query_marlin_supported_quant_types,
 )
+from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
+    FP4_MARLIN_SUPPORTED_GROUP_SIZES,
+    rand_marlin_weight_fp4_like,
+)
+from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
+    marlin_quant_fp8_torch,
+)
 from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
     MarlinWorkspace,
+    awq_marlin_quantize,
     marlin_quantize,
 )
 from vllm.model_executor.layers.quantization.utils.marlin_utils_test_24 import (
@@ -35,7 +43,7 @@
     quantize_weights,
     sort_weights,
 )
-from vllm.scalar_type import ScalarType
+from vllm.scalar_type import ScalarType, scalar_types
 from vllm.utils import FlexibleArgumentParser

 DEFAULT_MODELS = ["meta-llama/Llama-2-7b-hf/TP1"]
@@ -57,80 +65,144 @@ def bench_run(
     size_n: int,
 ):
     label = "Quant Matmul"
-
     sub_label = "{}, act={} k_full={}, q={}, g={}, MKN=({}x{}x{})".format(
         model, act_order, is_k_full, str(quant_type), group_size, size_m, size_k, size_n
     )
-
     print(f"Testing: {sub_label}")

     a = torch.randn(size_m, size_k).to(torch.half).cuda()
     b = torch.rand(size_k, size_n).to(torch.half).cuda()
+    has_zp = quant_type in [scalar_types.uint4, scalar_types.uint8]
+    if act_order and (group_size == -1 or group_size == size_k or has_zp):
+        return
+    if size_k % group_size != 0:
+        return

-    a_tmp = torch.zeros(size_m, size_k).to(torch.half).cuda()
+    marlin_24_supported = (
+        quant_type in GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES
+        and group_size in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES
+    )
+    repack_supported = (
+        quant_type in GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES
+        and group_size in MARLIN_SUPPORTED_GROUP_SIZES
+    )
+    allspark_supported = (
+        quant_type in ALLSPARK_SUPPORTED_QUANT_TYPES
+        and group_size == -1
+        and not act_order
+        and is_k_full
+    )
+
+    def gen_marlin_params():
+        # Marlin quant
+        marlin_g_idx = marlin_sort_indices = marlin_zp = marlin_s2 = None
+        if quant_type == scalar_types.float4_e2m1f:
+            if group_size != 16 or act_order:
+                return
+            marlin_w_ref, marlin_q_w, marlin_s, marlin_s2 = rand_marlin_weight_fp4_like(
+                b.T, group_size
+            )
+        elif quant_type == scalar_types.float8_e4m3fn:
+            if group_size not in [-1, 128] or act_order:
+                return
+            marlin_w_ref, marlin_q_w, marlin_s = marlin_quant_fp8_torch(b.T, group_size)
+        elif group_size == 16:
+            return
+        elif has_zp:
+            marlin_w_ref, marlin_q_w, marlin_s, marlin_zp = awq_marlin_quantize(
+                b, quant_type, group_size
+            )
+        else:
+            marlin_w_ref, marlin_q_w, marlin_s, marlin_g_idx, marlin_sort_indices, _ = (
+                marlin_quantize(b, quant_type, group_size, act_order)
+            )
+        return (
+            marlin_w_ref,
+            marlin_q_w,
+            marlin_s,
+            marlin_s2,
+            marlin_zp,
+            marlin_g_idx,
+            marlin_sort_indices,
+        )
+
+    def gen_marlin_24_params():
+        marlin_24_w_ref = marlin_24_q_w_comp = marlin_24_meta = marlin_24_s = None
+        if marlin_24_supported:
+            (marlin_24_w_ref, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s) = (
+                marlin_24_quantize(b, quant_type, group_size)
+            )
+        return (marlin_24_w_ref, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s)
+
+    def gen_repack_params():
+        q_w_gptq = None
+        repack_sort_indices = None
+        if repack_supported:
+            (w_ref, q_w, s, g_idx, rand_perm) = gptq_quantize_weights(
+                b, quant_type, group_size, act_order
+            )
+            q_w_gptq = gptq_pack(q_w, quant_type.size_bits, size_k, size_n)
+
+            # For act_order, sort the "weights" and "g_idx"
+            # so that group ids are increasing
+            repack_sort_indices = torch.empty(0, dtype=torch.int, device=b.device)
+            if act_order:
+                (q_w, g_idx, repack_sort_indices) = sort_weights(q_w, g_idx)
+        return q_w_gptq, repack_sort_indices
+
+    def gen_allspark_params():
+        qw_reorder = s_reorder = zp_reorder = sm_count = sm_version = (
+            CUBLAS_M_THRESHOLD
+        ) = None
+        nonlocal allspark_supported
+        if allspark_supported:
+            properties = torch.cuda.get_device_properties(b.device.index)
+            sm_count = properties.multi_processor_count
+            sm_version = properties.major * 10 + properties.minor
+
+            supported_arch = sm_version >= 80 and sm_version < 90
+            allspark_supported = allspark_supported and supported_arch
+            if supported_arch:
+                w_ref, qw, s, zp = quantize_weights(b, quant_type, group_size, has_zp)
+                qw = qw.to(torch.uint8)
+
+                qw_reorder, s_reorder, zp_reorder = ops.allspark_repack_weight(
+                    qw, s, zp, has_zp
+                )
+                CUBLAS_M_THRESHOLD = ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD
+        return (
+            qw_reorder,
+            s_reorder,
+            zp_reorder,
+            sm_count,
+            sm_version,
+            CUBLAS_M_THRESHOLD,
+        )

-    # Marlin quant
     (
         marlin_w_ref,
         marlin_q_w,
         marlin_s,
+        marlin_s2,
+        marlin_zp,
         marlin_g_idx,
         marlin_sort_indices,
-        marlin_rand_perm,
-    ) = marlin_quantize(b, quant_type, group_size, act_order)
-
-    # Marlin_24 quant
-    (marlin_24_w_ref, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s) = (
-        marlin_24_quantize(b, quant_type, group_size)
+    ) = gen_marlin_params()
+    marlin_24_w_ref, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s = (
+        gen_marlin_24_params()
     )
-
-    marlin_zp = torch.empty(0, dtype=torch.int, device=b.device)
-
-    # GPTQ quant
-    (w_ref, q_w, s, g_idx, rand_perm) = gptq_quantize_weights(
-        b, quant_type, group_size, act_order
+    q_w_gptq, repack_sort_indices = gen_repack_params()
+    qw_reorder, s_reorder, zp_reorder, sm_count, sm_version, CUBLAS_M_THRESHOLD = (
+        gen_allspark_params()
     )
-    q_w_gptq = gptq_pack(q_w, quant_type.size_bits, size_k, size_n)
-
-    # For act_order, sort the "weights" and "g_idx"
-    # so that group ids are increasing
-    repack_sort_indices = torch.empty(0, dtype=torch.int, device=b.device)
-    if act_order:
-        (q_w, g_idx, repack_sort_indices) = sort_weights(q_w, g_idx)

     # Prepare
     marlin_workspace = MarlinWorkspace(
         size_n, GPTQ_MARLIN_MIN_THREAD_N, GPTQ_MARLIN_MAX_PARALLEL
     )
-
     marlin_24_workspace = MarlinWorkspace(
         size_n, GPTQ_MARLIN_24_MIN_THREAD_N, GPTQ_MARLIN_24_MAX_PARALLEL
     )
-    marlin_zp = torch.zeros_like(marlin_s, dtype=torch.int)
-
-    # AllSpark W8A16 quant
-    as_supported_case = (
-        quant_type in ALLSPARK_SUPPORTED_QUANT_TYPES
-        and group_size == -1
-        and not act_order
-        and is_k_full
-    )
-    if as_supported_case:
-        properties = torch.cuda.get_device_properties(b.device.index)
-        sm_count = properties.multi_processor_count
-        sm_version = properties.major * 10 + properties.minor
-
-        supported_arch = sm_version >= 80 and sm_version < 90
-        as_supported_case = as_supported_case and supported_arch
-        if supported_arch:
-            has_zp = False
-            w_ref, qw, s, zp = quantize_weights(b, quant_type, group_size, has_zp)
-            qw = qw.to(torch.uint8)
-
-            qw_reorder, s_reorder, zp_reorder = ops.allspark_repack_weight(
-                qw, s, zp, has_zp
-            )
-            CUBLAS_M_THRESHOLD = ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD

     globals = {
         # Gen params
@@ -140,15 +212,14 @@ def bench_run(
         "size_n": size_n,
         "size_k": size_k,
         "a": a,
-        "a_tmp": a_tmp,
         # Marlin params
         "marlin_w_ref": marlin_w_ref,
         "marlin_q_w": marlin_q_w,
         "marlin_s": marlin_s,
+        "marlin_s2": marlin_s2,
         "marlin_zp": marlin_zp,
         "marlin_g_idx": marlin_g_idx,
         "marlin_sort_indices": marlin_sort_indices,
-        "marlin_rand_perm": marlin_rand_perm,
         "marlin_workspace": marlin_workspace,
         "is_k_full": is_k_full,
         # Marlin_24 params
@@ -161,12 +232,12 @@ def bench_run(
         "q_w_gptq": q_w_gptq,
         "repack_sort_indices": repack_sort_indices,
         # AllSpark W8A16 params
-        "qw_reorder": qw_reorder if as_supported_case else None,
-        "s_reorder": s_reorder if as_supported_case else None,
-        "zp_reorder": zp_reorder if as_supported_case else None,
-        "sm_count": sm_count if as_supported_case else None,
-        "sm_version": sm_version if as_supported_case else None,
-        "CUBLAS_M_THRESHOLD": CUBLAS_M_THRESHOLD if as_supported_case else None,
+        "qw_reorder": qw_reorder,
+        "s_reorder": s_reorder,
+        "zp_reorder": zp_reorder,
+        "sm_count": sm_count,
+        "sm_version": sm_version,
+        "CUBLAS_M_THRESHOLD": CUBLAS_M_THRESHOLD,
         # Kernels
         "gptq_marlin_gemm": ops.gptq_marlin_gemm,
         "gptq_marlin_24_gemm": ops.gptq_marlin_24_gemm,
@@ -177,7 +248,7 @@ def bench_run(
     min_run_time = 1

     # Warmup pytorch
-    for i in range(5):
+    for _ in range(5):
         torch.matmul(a, marlin_w_ref)

     results.append(
@@ -192,28 +263,25 @@ def bench_run(

     results.append(
         benchmark.Timer(
-            stmt="output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, False, False)", # noqa: E501
+            stmt="output = gptq_marlin_gemm(a, None, marlin_q_w, marlin_s, marlin_s2, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, False, False)", # noqa: E501
             globals=globals,
             label=label,
             sub_label=sub_label,
-            description="gptq_marlin_gemm_fp16",
+            description="gptq_marlin_gemm",
         ).blocked_autorange(min_run_time=min_run_time)
     )

     results.append(
         benchmark.Timer(
-            stmt="output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, True, False)", # noqa: E501
+            stmt="output = gptq_marlin_gemm(a, None, marlin_q_w, marlin_s, marlin_s2, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, True, False)", # noqa: E501
             globals=globals,
             label=label,
             sub_label=sub_label,
             description="gptq_marlin_gemm_fp32",
         ).blocked_autorange(min_run_time=min_run_time)
     )

-    if (
-        quant_type in GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES
-        and group_size in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES
-    ):
+    if marlin_24_supported:
         results.append(
             benchmark.Timer(
                 stmt="output = gptq_marlin_24_gemm(a, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s, marlin_24_workspace.scratch, quant_type, size_m, size_n, size_k)", # noqa: E501
@@ -224,17 +292,18 @@ def bench_run(
             ).blocked_autorange(min_run_time=min_run_time)
         )

-    results.append(
-        benchmark.Timer(
-            stmt="q_res = gptq_marlin_repack(q_w_gptq, repack_sort_indices, size_k, size_n, quant_type.size_bits)", # noqa: E501
-            globals=globals,
-            label=label,
-            sub_label=sub_label,
-            description="gptq_marlin_repack",
-        ).blocked_autorange(min_run_time=min_run_time)
-    )
+    if repack_supported:
+        results.append(
+            benchmark.Timer(
+                stmt="q_res = gptq_marlin_repack(q_w_gptq, repack_sort_indices, size_k, size_n, quant_type.size_bits)", # noqa: E501
+                globals=globals,
+                label=label,
+                sub_label=sub_label,
+                description="gptq_marlin_repack",
+            ).blocked_autorange(min_run_time=min_run_time)
+        )

-    if as_supported_case:
+    if allspark_supported:
         results.append(
             benchmark.Timer(
                 stmt="output = allspark_w8a16_gemm(a, qw_reorder, s_reorder, zp_reorder, size_n, group_size, sm_count, sm_version, CUBLAS_M_THRESHOLD, False, True)", # noqa: E501
@@ -250,7 +319,6 @@ def main(args):
     print("Benchmarking models:")
     for i, model in enumerate(args.models):
         print(f"[{i}] {model}")
-
    results: list[benchmark.Measurement] = []

    for model in args.models:
@@ -278,14 +346,17 @@ def main(args):
                    ):
                        continue

-                    for quant_type in query_marlin_supported_quant_types(False):
+                    for quant_type in query_marlin_supported_quant_types():
                        if (
                            len(args.limit_num_bits) > 0
                            and quant_type.size_bits not in args.limit_num_bits
                        ):
                            continue

-                        for group_size in MARLIN_SUPPORTED_GROUP_SIZES:
+                        for group_size in (
+                            MARLIN_SUPPORTED_GROUP_SIZES
+                            + FP4_MARLIN_SUPPORTED_GROUP_SIZES
+                        ):
                            if (
                                len(args.limit_group_size) > 0
                                and group_size not in args.limit_group_size
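
For readers skimming the diff: the core of the change is that every quantization path now gets an explicit *_supported flag plus a small gen_*_params() helper that returns None placeholders when the current (quant_type, group_size, act_order) combination does not apply, so the corresponding timers can simply be skipped. Below is a minimal, self-contained sketch of that guard-and-placeholder pattern; fake_quantize and SUPPORTED_GROUP_SIZES are illustrative stand-ins for this note, not vLLM APIs.

# Minimal sketch of the guard-and-placeholder pattern used in the refactor.
# fake_quantize and SUPPORTED_GROUP_SIZES are illustrative stand-ins, not vLLM APIs.
import torch

SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]


def fake_quantize(weight: torch.Tensor, group_size: int):
    # Stand-in for a real quantization helper; returns (ref_weight, scales).
    return weight, torch.ones(weight.shape[1])


def gen_params(weight: torch.Tensor, group_size: int, supported: bool):
    # Return None placeholders when this config is not benchmarked, so the
    # caller can unpack unconditionally and guard only the timer itself.
    w_ref = scales = None
    if supported:
        w_ref, scales = fake_quantize(weight, group_size)
    return w_ref, scales


if __name__ == "__main__":
    b = torch.rand(128, 256)
    for group_size in SUPPORTED_GROUP_SIZES + [16]:
        supported = group_size in SUPPORTED_GROUP_SIZES
        w_ref, scales = gen_params(b, group_size, supported)
        if supported:
            print(f"group_size={group_size}: would run timer")
        else:
            print(f"group_size={group_size}: skipped")

One consequence of this structure, visible in the globals-dict hunk above, is that every key is always present and merely set to None when unused, which is why the per-key "if as_supported_case else None" guards could be dropped.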
