     MARLIN_SUPPORTED_GROUP_SIZES,
     query_marlin_supported_quant_types,
 )
+from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
+    FP4_MARLIN_SUPPORTED_GROUP_SIZES,
+    rand_marlin_weight_fp4_like,
+)
+from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
+    marlin_quant_fp8_torch,
+)
 from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
     MarlinWorkspace,
+    awq_marlin_quantize,
     marlin_quantize,
 )
 from vllm.model_executor.layers.quantization.utils.marlin_utils_test_24 import (
@@ -35,7 +43,7 @@
     quantize_weights,
     sort_weights,
 )
-from vllm.scalar_type import ScalarType
+from vllm.scalar_type import ScalarType, scalar_types
 from vllm.utils import FlexibleArgumentParser
 
 DEFAULT_MODELS = ["meta-llama/Llama-2-7b-hf/TP1"]
@@ -57,80 +65,144 @@ def bench_run(
     size_n: int,
 ):
     label = "Quant Matmul"
-
     sub_label = "{}, act={} k_full={}, q={}, g={}, MKN=({}x{}x{})".format(
         model, act_order, is_k_full, str(quant_type), group_size, size_m, size_k, size_n
     )
-
     print(f"Testing: {sub_label}")
 
     a = torch.randn(size_m, size_k).to(torch.half).cuda()
     b = torch.rand(size_k, size_n).to(torch.half).cuda()
+    has_zp = quant_type in [scalar_types.uint4, scalar_types.uint8]
+    if act_order and (group_size == -1 or group_size == size_k or has_zp):
+        return
+    if size_k % group_size != 0:
+        return
 
-    a_tmp = torch.zeros(size_m, size_k).to(torch.half).cuda()
+    marlin_24_supported = (
+        quant_type in GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES
+        and group_size in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES
+    )
+    repack_supported = (
+        quant_type in GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES
+        and group_size in MARLIN_SUPPORTED_GROUP_SIZES
+    )
+    allspark_supported = (
+        quant_type in ALLSPARK_SUPPORTED_QUANT_TYPES
+        and group_size == -1
+        and not act_order
+        and is_k_full
+    )
+
+    def gen_marlin_params():
+        # Marlin quant
+        marlin_g_idx = marlin_sort_indices = marlin_zp = marlin_s2 = None
+        if quant_type == scalar_types.float4_e2m1f:
+            if group_size != 16 or act_order:
+                return
+            marlin_w_ref, marlin_q_w, marlin_s, marlin_s2 = rand_marlin_weight_fp4_like(
+                b.T, group_size
+            )
+        elif quant_type == scalar_types.float8_e4m3fn:
+            if group_size not in [-1, 128] or act_order:
+                return
+            marlin_w_ref, marlin_q_w, marlin_s = marlin_quant_fp8_torch(b.T, group_size)
+        elif group_size == 16:
+            return
+        elif has_zp:
+            marlin_w_ref, marlin_q_w, marlin_s, marlin_zp = awq_marlin_quantize(
+                b, quant_type, group_size
+            )
+        else:
+            marlin_w_ref, marlin_q_w, marlin_s, marlin_g_idx, marlin_sort_indices, _ = (
+                marlin_quantize(b, quant_type, group_size, act_order)
+            )
+        return (
+            marlin_w_ref,
+            marlin_q_w,
+            marlin_s,
+            marlin_s2,
+            marlin_zp,
+            marlin_g_idx,
+            marlin_sort_indices,
+        )
+
+    def gen_marlin_24_params():
+        marlin_24_w_ref = marlin_24_q_w_comp = marlin_24_meta = marlin_24_s = None
+        if marlin_24_supported:
+            (marlin_24_w_ref, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s) = (
+                marlin_24_quantize(b, quant_type, group_size)
+            )
+        return (marlin_24_w_ref, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s)
+
+    def gen_repack_params():
+        q_w_gptq = None
+        repack_sort_indices = None
+        if repack_supported:
+            (w_ref, q_w, s, g_idx, rand_perm) = gptq_quantize_weights(
+                b, quant_type, group_size, act_order
+            )
+            q_w_gptq = gptq_pack(q_w, quant_type.size_bits, size_k, size_n)
+
+            # For act_order, sort the "weights" and "g_idx"
+            # so that group ids are increasing
+            repack_sort_indices = torch.empty(0, dtype=torch.int, device=b.device)
+            if act_order:
+                (q_w, g_idx, repack_sort_indices) = sort_weights(q_w, g_idx)
+        return q_w_gptq, repack_sort_indices
+
+    def gen_allspark_params():
+        qw_reorder = s_reorder = zp_reorder = sm_count = sm_version = (
+            CUBLAS_M_THRESHOLD
+        ) = None
+        nonlocal allspark_supported
+        if allspark_supported:
+            properties = torch.cuda.get_device_properties(b.device.index)
+            sm_count = properties.multi_processor_count
+            sm_version = properties.major * 10 + properties.minor
+
+            supported_arch = sm_version >= 80 and sm_version < 90
+            allspark_supported = allspark_supported and supported_arch
+            if supported_arch:
+                w_ref, qw, s, zp = quantize_weights(b, quant_type, group_size, has_zp)
+                qw = qw.to(torch.uint8)
+
+                qw_reorder, s_reorder, zp_reorder = ops.allspark_repack_weight(
+                    qw, s, zp, has_zp
+                )
+                CUBLAS_M_THRESHOLD = ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD
+        return (
+            qw_reorder,
+            s_reorder,
+            zp_reorder,
+            sm_count,
+            sm_version,
+            CUBLAS_M_THRESHOLD,
+        )
 
-    # Marlin quant
     (
         marlin_w_ref,
         marlin_q_w,
         marlin_s,
+        marlin_s2,
+        marlin_zp,
         marlin_g_idx,
         marlin_sort_indices,
-        marlin_rand_perm,
-    ) = marlin_quantize(b, quant_type, group_size, act_order)
-
-    # Marlin_24 quant
-    (marlin_24_w_ref, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s) = (
-        marlin_24_quantize(b, quant_type, group_size)
+    ) = gen_marlin_params()
+    marlin_24_w_ref, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s = (
+        gen_marlin_24_params()
     )
-
-    marlin_zp = torch.empty(0, dtype=torch.int, device=b.device)
-
-    # GPTQ quant
-    (w_ref, q_w, s, g_idx, rand_perm) = gptq_quantize_weights(
-        b, quant_type, group_size, act_order
+    q_w_gptq, repack_sort_indices = gen_repack_params()
+    qw_reorder, s_reorder, zp_reorder, sm_count, sm_version, CUBLAS_M_THRESHOLD = (
+        gen_allspark_params()
     )
-    q_w_gptq = gptq_pack(q_w, quant_type.size_bits, size_k, size_n)
-
-    # For act_order, sort the "weights" and "g_idx"
-    # so that group ids are increasing
-    repack_sort_indices = torch.empty(0, dtype=torch.int, device=b.device)
-    if act_order:
-        (q_w, g_idx, repack_sort_indices) = sort_weights(q_w, g_idx)
 
     # Prepare
     marlin_workspace = MarlinWorkspace(
         size_n, GPTQ_MARLIN_MIN_THREAD_N, GPTQ_MARLIN_MAX_PARALLEL
     )
-
     marlin_24_workspace = MarlinWorkspace(
         size_n, GPTQ_MARLIN_24_MIN_THREAD_N, GPTQ_MARLIN_24_MAX_PARALLEL
     )
-    marlin_zp = torch.zeros_like(marlin_s, dtype=torch.int)
-
-    # AllSpark W8A16 quant
-    as_supported_case = (
-        quant_type in ALLSPARK_SUPPORTED_QUANT_TYPES
-        and group_size == -1
-        and not act_order
-        and is_k_full
-    )
-    if as_supported_case:
-        properties = torch.cuda.get_device_properties(b.device.index)
-        sm_count = properties.multi_processor_count
-        sm_version = properties.major * 10 + properties.minor
-
-        supported_arch = sm_version >= 80 and sm_version < 90
-        as_supported_case = as_supported_case and supported_arch
-        if supported_arch:
-            has_zp = False
-            w_ref, qw, s, zp = quantize_weights(b, quant_type, group_size, has_zp)
-            qw = qw.to(torch.uint8)
-
-            qw_reorder, s_reorder, zp_reorder = ops.allspark_repack_weight(
-                qw, s, zp, has_zp
-            )
-            CUBLAS_M_THRESHOLD = ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD
 
     globals = {
         # Gen params
@@ -140,15 +212,14 @@ def bench_run(
         "size_n": size_n,
         "size_k": size_k,
         "a": a,
-        "a_tmp": a_tmp,
         # Marlin params
         "marlin_w_ref": marlin_w_ref,
         "marlin_q_w": marlin_q_w,
         "marlin_s": marlin_s,
+        "marlin_s2": marlin_s2,
         "marlin_zp": marlin_zp,
         "marlin_g_idx": marlin_g_idx,
         "marlin_sort_indices": marlin_sort_indices,
-        "marlin_rand_perm": marlin_rand_perm,
         "marlin_workspace": marlin_workspace,
         "is_k_full": is_k_full,
         # Marlin_24 params
@@ -161,12 +232,12 @@ def bench_run(
         "q_w_gptq": q_w_gptq,
         "repack_sort_indices": repack_sort_indices,
         # AllSpark W8A16 params
-        "qw_reorder": qw_reorder if as_supported_case else None,
-        "s_reorder": s_reorder if as_supported_case else None,
-        "zp_reorder": zp_reorder if as_supported_case else None,
-        "sm_count": sm_count if as_supported_case else None,
-        "sm_version": sm_version if as_supported_case else None,
-        "CUBLAS_M_THRESHOLD": CUBLAS_M_THRESHOLD if as_supported_case else None,
+        "qw_reorder": qw_reorder,
+        "s_reorder": s_reorder,
+        "zp_reorder": zp_reorder,
+        "sm_count": sm_count,
+        "sm_version": sm_version,
+        "CUBLAS_M_THRESHOLD": CUBLAS_M_THRESHOLD,
         # Kernels
         "gptq_marlin_gemm": ops.gptq_marlin_gemm,
         "gptq_marlin_24_gemm": ops.gptq_marlin_24_gemm,
@@ -177,7 +248,7 @@ def bench_run(
     min_run_time = 1
 
     # Warmup pytorch
-    for i in range(5):
+    for _ in range(5):
         torch.matmul(a, marlin_w_ref)
 
     results.append(
@@ -192,28 +263,25 @@ def bench_run(
 
     results.append(
         benchmark.Timer(
-            stmt="output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, False, False)",  # noqa: E501
+            stmt="output = gptq_marlin_gemm(a, None, marlin_q_w, marlin_s, marlin_s2, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, False, False)",  # noqa: E501
             globals=globals,
             label=label,
             sub_label=sub_label,
-            description="gptq_marlin_gemm_fp16",
+            description="gptq_marlin_gemm",
         ).blocked_autorange(min_run_time=min_run_time)
     )
 
     results.append(
         benchmark.Timer(
-            stmt="output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, True, False)",  # noqa: E501
+            stmt="output = gptq_marlin_gemm(a, None, marlin_q_w, marlin_s, marlin_s2, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, True, False)",  # noqa: E501
             globals=globals,
             label=label,
             sub_label=sub_label,
             description="gptq_marlin_gemm_fp32",
         ).blocked_autorange(min_run_time=min_run_time)
     )
 
-    if (
-        quant_type in GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES
-        and group_size in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES
-    ):
+    if marlin_24_supported:
         results.append(
             benchmark.Timer(
                 stmt="output = gptq_marlin_24_gemm(a, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s, marlin_24_workspace.scratch, quant_type, size_m, size_n, size_k)",  # noqa: E501
@@ -224,17 +292,18 @@ def bench_run(
             ).blocked_autorange(min_run_time=min_run_time)
         )
 
-    results.append(
-        benchmark.Timer(
-            stmt="q_res = gptq_marlin_repack(q_w_gptq, repack_sort_indices, size_k, size_n, quant_type.size_bits)",  # noqa: E501
-            globals=globals,
-            label=label,
-            sub_label=sub_label,
-            description="gptq_marlin_repack",
-        ).blocked_autorange(min_run_time=min_run_time)
-    )
+    if repack_supported:
+        results.append(
+            benchmark.Timer(
+                stmt="q_res = gptq_marlin_repack(q_w_gptq, repack_sort_indices, size_k, size_n, quant_type.size_bits)",  # noqa: E501
+                globals=globals,
+                label=label,
+                sub_label=sub_label,
+                description="gptq_marlin_repack",
+            ).blocked_autorange(min_run_time=min_run_time)
+        )
 
-    if as_supported_case:
+    if allspark_supported:
         results.append(
             benchmark.Timer(
                 stmt="output = allspark_w8a16_gemm(a, qw_reorder, s_reorder, zp_reorder, size_n, group_size, sm_count, sm_version, CUBLAS_M_THRESHOLD, False, True)",  # noqa: E501
@@ -250,7 +319,6 @@ def main(args):
     print("Benchmarking models:")
    for i, model in enumerate(args.models):
         print(f"[{i}] {model}")
-
     results: list[benchmark.Measurement] = []
 
     for model in args.models:
@@ -278,14 +346,17 @@
                     ):
                         continue
 
-                    for quant_type in query_marlin_supported_quant_types(False):
+                    for quant_type in query_marlin_supported_quant_types():
                         if (
                             len(args.limit_num_bits) > 0
                             and quant_type.size_bits not in args.limit_num_bits
                         ):
                             continue
 
-                        for group_size in MARLIN_SUPPORTED_GROUP_SIZES:
+                        for group_size in (
+                            MARLIN_SUPPORTED_GROUP_SIZES
+                            + FP4_MARLIN_SUPPORTED_GROUP_SIZES
+                        ):
                             if (
                                 len(args.limit_group_size) > 0
                                 and group_size not in args.limit_group_size
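
Note (not part of the patch): the table below is a GPU-free sketch of the quant-path dispatch that the new `gen_marlin_params()` helper performs in the hunk above. The quant-type names and the returned helper names are plain illustrative strings, not vLLM objects; it only mirrors the branch conditions visible in the diff.

```python
def pick_marlin_quant_path(quant_name: str, group_size: int, act_order: bool, has_zp: bool):
    """Return which weight-quantization helper the benchmark would use, or None if skipped."""
    if quant_name == "float4_e2m1f":
        # FP4 weights: only group_size 16 and no activation reordering.
        return None if (group_size != 16 or act_order) else "rand_marlin_weight_fp4_like"
    if quant_name == "float8_e4m3fn":
        # FP8 weights: channelwise (-1) or group size 128, no activation reordering.
        return None if (group_size not in (-1, 128) or act_order) else "marlin_quant_fp8_torch"
    if group_size == 16:
        return None  # group_size 16 is reserved for the FP4 path
    if has_zp:
        return "awq_marlin_quantize"  # uint4/uint8 weights with zero points
    return "marlin_quantize"  # GPTQ-style path


if __name__ == "__main__":
    cases = [
        ("float4_e2m1f", 16, False, False),
        ("float8_e4m3fn", 128, False, False),
        ("uint4b8", 128, True, False),
        ("uint4", 128, False, True),
    ]
    for case in cases:
        print(case, "->", pick_marlin_quant_path(*case))
```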