 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 from dataclasses import dataclass
+from typing import Optional
 
 import pytest
 import torch
 import triton.language as tl
 
+from tests.kernels.moe.utils import (
+    batched_moe,
+    make_test_weights,
+    make_quantized_test_activations,
+    torch_moe2,
+    triton_moe)
+from tests.kernels.quant_utils import native_w8a8_block_matmul
+from vllm.config import VllmConfig, set_current_vllm_config
 from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
     invoke_moe_batched_triton_kernel)
+from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
+from vllm.platforms import current_platform
+
+NUM_EXPERTS = [8, 64]
+TOP_KS = [1, 2, 6]
+
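+# Test-scoped vLLM config; the fused-experts test below activates it via
+# set_current_vllm_config() so the MoE helpers see consistent limits.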
+vllm_config = VllmConfig()
+vllm_config.scheduler_config.max_num_seqs = 128
+vllm_config.scheduler_config.max_model_len = 8192
 
 
 @dataclass
 class BatchedMMConfig:
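+    # in_dtype/out_dtype are the unquantized activation/output dtypes;
+    # quant_dtype, when set, is the quantized on-device dtype (e.g. fp8).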
-    dtype: torch.dtype
+    in_dtype: torch.dtype
+    quant_dtype: Optional[torch.dtype]
+    out_dtype: torch.dtype
     num_experts: int
     max_tokens_per_expert: int
     K: int
@@ -32,84 +52,220 @@ def make_tensors(config: BatchedMMConfig):
         A = torch.randn(
             (config.num_experts, config.max_tokens_per_expert, config.K),
             device="cuda",
-            dtype=config.dtype) / 10
+            dtype=config.in_dtype) / 10
         B = torch.randn((config.num_experts, config.N, config.K),
                         device="cuda",
-                        dtype=config.dtype)
+                        dtype=config.in_dtype)
         C = torch.zeros(
             (config.num_experts, config.max_tokens_per_expert, config.N),
             device="cuda",
-            dtype=config.dtype)
+            dtype=config.out_dtype)
+
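         # Ragged batch: expert e only has num_expert_tokens[e] valid rows;
         # the remaining rows of A and C are padding.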
         num_expert_tokens = torch.randint(low=0,
                                           high=config.max_tokens_per_expert,
                                           size=(config.num_experts, ),
                                           device="cuda",
                                           dtype=torch.int32)
-        return BatchedMMTensors(A, B, C, num_expert_tokens)
 
 
-def ref_impl(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
-             num_expert_tokens: torch.Tensor) -> torch.Tensor:
+        return BatchedMMTensors(A, B, C, num_expert_tokens)
+
+
+def ref_impl(
+    A: torch.Tensor,
+    B: torch.Tensor,
+    C: torch.Tensor,
+    num_expert_tokens: torch.Tensor,
+    A_scale: Optional[torch.Tensor],
+    B_scale: Optional[torch.Tensor],
+    block_shape: Optional[list[int]],
+) -> torch.Tensor:
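+    # Reference batched matmul: one (tokens x K) @ (K x N) GEMM per expert.
+    # For fp8 inputs (itemsize == 1) the scales are applied, blockwise when
+    # block_shape is given.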
     num_expert_tokens_cpu = num_expert_tokens.clone()
     num_expert_tokens_cpu = num_expert_tokens_cpu.to(device="cpu")
     num_experts = num_expert_tokens.size(0)
+    f32 = torch.float32
+    bf16 = torch.bfloat16
+
     for e in range(num_experts):
         num_tokens = num_expert_tokens_cpu[e]
-        C[e, :num_tokens, :] = A[e, :num_tokens, :] @ B[e].transpose(0, 1)
+        if A.dtype.itemsize == 1 and block_shape is not None:
+            tmp = native_w8a8_block_matmul(A[e], B[e], A_scale[e], B_scale[e],
+                                           block_shape, C.dtype)
+            C[e, :num_tokens, :] = tmp[:num_tokens, :]
+        elif A.dtype.itemsize == 1 and block_shape is None:
+            C[e, :num_tokens, :] = (
+                (A[e, :num_tokens, :].to(f32) * A_scale[e]).to(bf16)
+                @ (B[e].transpose(0, 1).to(f32) * B_scale[e]).to(bf16))
+        else:
+            assert A_scale is None
+            assert B_scale is None
+            C[e, :num_tokens, :] = A[e, :num_tokens, :] @ B[e].transpose(0, 1)
 
     return C
 
 
-@pytest.mark.parametrize("num_experts", [16, 32])
+@pytest.mark.parametrize("num_experts", [8, 16, 32])
 @pytest.mark.parametrize("max_tokens_per_expert",
                          [32, 64, 128, 192, 224, 256, 512])
 @pytest.mark.parametrize("K", [128, 256, 1024])
 @pytest.mark.parametrize("N", [128, 256, 512, 1024])
-@pytest.mark.parametrize("dtype",
-                         [torch.float32, torch.float16, torch.bfloat16])
+@pytest.mark.parametrize(
+    "dtype",
+    [torch.float8_e4m3fn, torch.float32, torch.float16, torch.bfloat16])
+@pytest.mark.parametrize("block_shape", [None])
+@pytest.mark.parametrize("per_act_token_quant", [False])
 def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
-                    N: int, dtype: torch.dtype):
+                    N: int, dtype: torch.dtype,
+                    block_shape: Optional[list[int]],
+                    per_act_token_quant: bool):
+    current_platform.seed_everything(7)
 
-    config = BatchedMMConfig(dtype, num_experts, max_tokens_per_expert, K, N)
-    tensors = BatchedMMTensors.make_tensors(config)
+    use_fp8_w8a8 = dtype == torch.float8_e4m3fn
 
-    test_output = tensors.C
-    ref_output = test_output.clone()
+    if block_shape is not None and not use_fp8_w8a8:
+        pytest.skip("Don't test blocking for non-quantized types.")
+
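+    # fp8 case: generate activations in bf16 and quantize them; the kernel
+    # output stays in the unquantized activation dtype.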
+    if dtype.itemsize == 1:
+        act_dtype = torch.bfloat16
+        quant_dtype = dtype
+    else:
+        act_dtype = dtype
+        quant_dtype = None
+
+    num_expert_tokens = torch.randint(low=0,
+                                      high=max_tokens_per_expert,
+                                      size=(num_experts, ),
+                                      device="cuda",
+                                      dtype=torch.int32)
+
+    A, A_q, A_scale = make_quantized_test_activations(
+        num_experts,
+        max_tokens_per_expert,
+        K,
+        in_dtype=act_dtype,
+        quant_dtype=quant_dtype,
+        block_shape=block_shape,
+        per_act_token_quant=per_act_token_quant,
+    )
+
+    B, B_q, B_scale, _, _, _ = make_test_weights(
+        num_experts,
+        N // 2,
+        K,
+        quant_dtype=dtype,
+        block_shape=block_shape,
+    )
+
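+    # Three buffers: the kernel output under test, an unquantized reference,
+    # and a reference computed from the quantized operands.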
+    out_shape = (num_experts, max_tokens_per_expert, N)
+    test_output = torch.zeros(out_shape, dtype=act_dtype, device="cuda")
+    ref_output = torch.zeros(out_shape, dtype=act_dtype, device="cuda")
+    q_ref_output = torch.zeros(out_shape, dtype=act_dtype, device="cuda")
 
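     # Map the (unquantized) output dtype to the kernel's Triton compute dtype.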
     compute_tl_dtype = {
         torch.float16: tl.float16,
         torch.bfloat16: tl.bfloat16,
         torch.float32: tl.float32
     }[test_output.dtype]
+
     invoke_moe_batched_triton_kernel(
-        tensors.A,
-        tensors.B,
+        A_q,
+        B_q,
         test_output,
-        tensors.num_expert_tokens,
+        num_expert_tokens,
         compute_tl_dtype,
         # Quantization data
-        None,
-        None,
+        A_scale,
+        B_scale,
         None,
         # Quantization schemes
-        False,
+        use_fp8_w8a8,
         False,
         False,
         config={
             "BLOCK_SIZE_M": 16,
             "BLOCK_SIZE_N": 16,
             "BLOCK_SIZE_K": 16
-        })
+        },
+        block_shape=block_shape,
+    )
+
+    ref_output = ref_impl(
+        A,
+        B,
+        ref_output,
+        num_expert_tokens,
+        None,
+        None,
+        None,
+    )
 
-    ref_output = ref_impl(tensors.A, tensors.B, ref_output,
-                          tensors.num_expert_tokens)
+    q_ref_output = ref_impl(A_q, B_q, q_ref_output, num_expert_tokens, A_scale,
+                            B_scale, block_shape)
 
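     # Tolerances are keyed on the output dtype; fp8 runs compare bf16 outputs.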
     rtol, atol = {
         torch.float16: (6e-2, 6e-2),
         torch.bfloat16: (6e-2, 6e-2),
         torch.float32: (1e-2, 1e-2),
     }[test_output.dtype]
 
-    torch.testing.assert_close(test_output, ref_output, atol=atol, rtol=rtol)
+    torch.testing.assert_close(ref_output, q_ref_output, atol=atol, rtol=rtol)
+    torch.testing.assert_close(test_output, q_ref_output, atol=atol, rtol=rtol)
+
+
+@pytest.mark.parametrize("m", [1, 32, 45, 64, 222])
+@pytest.mark.parametrize("n", [128, 512, 1024, 2048])
+@pytest.mark.parametrize("k", [128, 512, 1024, 2048])
+@pytest.mark.parametrize("e", NUM_EXPERTS)
+@pytest.mark.parametrize("topk", TOP_KS)
+@pytest.mark.parametrize("dtype", [torch.float8_e4m3fn, torch.bfloat16])
+@pytest.mark.parametrize("per_act_token_quant", [False])
+@pytest.mark.parametrize("block_shape", [None])
+def test_fused_moe_batched_experts(
+    m: int,
+    n: int,
+    k: int,
+    e: int,
+    topk: int,
+    dtype: torch.dtype,
+    per_act_token_quant: bool,
+    block_shape: Optional[list[int]],
+):
+    current_platform.seed_everything(7)
+
+    use_fp8_w8a8 = dtype == torch.float8_e4m3fn
+    quant_type = torch.float8_e4m3fn if use_fp8_w8a8 else None
+
+    if not use_fp8_w8a8 and (per_act_token_quant or block_shape is not None):
+        pytest.skip("Skip quantization test for non-quantized type")
+
+    if (per_act_token_quant and block_shape is not None) or topk > e:
+        pytest.skip("Skip illegal quantization test")
+
+    a = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) / 10
+    score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)
+    _, w1, w1_s, _, w2, w2_s = make_test_weights(e,
+                                                 n,
+                                                 k,
+                                                 block_shape=block_shape,
+                                                 quant_dtype=dtype)
+
+    torch.set_printoptions(profile="full")
+
+    with set_current_vllm_config(vllm_config):
+        topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)
+        batched_output = batched_moe(a, w1, w2, topk_weight, topk_ids, w1_s,
+                                     w2_s, quant_type, per_act_token_quant,
+                                     block_shape)
+        baseline_output = torch_moe2(a, w1, w2, topk_weight, topk_ids, w1_s,
+                                     w2_s, quant_type, per_act_token_quant,
+                                     block_shape)
+        triton_output = triton_moe(a, w1, w2, topk_weight, topk_ids, w1_s,
+                                   w2_s, quant_type, per_act_token_quant,
+                                   block_shape)
+
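+    # Both the batched kernel path and the fused Triton path should agree
+    # with the torch reference.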
+    torch.testing.assert_close(triton_output,
+                               baseline_output,
+                               atol=2e-2,
+                               rtol=2e-2)
+
+    torch.testing.assert_close(triton_output,
+                               batched_output,
+                               atol=2e-2,
+                               rtol=2e-2)