@@ -133,6 +133,8 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
     act_dtype = dtype
     quant_dtype = None
 
+    #print(f"TYPES {dtype}, {act_dtype}, {quant_dtype}")
+
     num_expert_tokens = torch.randint(low=0,
                                       high=max_tokens_per_expert,
                                       size=(num_experts, ),
@@ -153,7 +155,8 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
                                                  num_experts,
                                                  N // 2,
                                                  K,
-                                                 quant_dtype=dtype,
+                                                 in_dtype=act_dtype,
+                                                 quant_dtype=quant_dtype,
                                                  block_shape=block_shape,
                                                  )
 
@@ -168,6 +171,8 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
         torch.float32: tl.float32
     }[test_output.dtype]
 
+    assert A_q.dtype == B_q.dtype
+
     invoke_moe_batched_triton_kernel(
         A_q,
         B_q,
@@ -185,7 +190,7 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
         config={
             "BLOCK_SIZE_M": 16,
             "BLOCK_SIZE_N": 16,
-            "BLOCK_SIZE_K": 16
+            "BLOCK_SIZE_K": 16 if dtype.itemsize > 1 else 32
         },
         block_shape=block_shape,
     )
@@ -209,7 +214,7 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
         torch.float32: (1e-2, 1e-2),
     }[test_output.dtype]
 
-    torch.testing.assert_close(ref_output, q_ref_output, atol=atol, rtol=rtol)
+    torch.testing.assert_close(ref_output, test_output, atol=atol, rtol=rtol)
     torch.testing.assert_close(test_output, q_ref_output, atol=atol, rtol=rtol)
 
 
@@ -234,7 +239,6 @@ def test_fused_moe_batched_experts(
     current_platform.seed_everything(7)
 
     use_fp8_w8a8 = dtype == torch.float8_e4m3fn
-    quant_type = torch.float8_e4m3fn if use_fp8_w8a8 else None
 
     if not use_fp8_w8a8 and per_act_token_quant and block_shape is not None:
         pytest.skip("Skip quantization test for non-quantized type")
@@ -244,20 +248,30 @@ def test_fused_moe_batched_experts(
 
     a = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) / 10
     score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)
-    _, w1, w1_s, _, w2, w2_s = make_test_weights(e, n, k, block_shape=block_shape, quant_dtype=dtype)
+
+    if dtype.itemsize == 1:
+        act_dtype = torch.bfloat16
+        quant_dtype = dtype
+    else:
+        act_dtype = dtype
+        quant_dtype = None
+
+    _, w1, w1_s, _, w2, w2_s = make_test_weights(e, n, k, block_shape=block_shape,
+                                                 in_dtype=act_dtype,
+                                                 quant_dtype=quant_dtype)
 
     torch.set_printoptions(profile="full")
 
     with set_current_vllm_config(vllm_config):
         topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)
         batched_output = batched_moe(a, w1, w2, topk_weight, topk_ids, w1_s,
-                                     w2_s, quant_type, per_act_token_quant,
+                                     w2_s, quant_dtype, per_act_token_quant,
                                      block_shape)
         baseline_output = torch_moe2(a, w1, w2, topk_weight, topk_ids, w1_s,
-                                     w2_s, quant_type, per_act_token_quant,
+                                     w2_s, quant_dtype, per_act_token_quant,
                                      block_shape)
         triton_output = triton_moe(a, w1, w2, topk_weight, topk_ids, w1_s,
-                                   w2_s, quant_type, per_act_token_quant,
+                                   w2_s, quant_dtype, per_act_token_quant,
                                    block_shape)
 
         torch.testing.assert_close(triton_output,
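
Note on BLOCK_SIZE_K: the bump from 16 to 32 for one-byte dtypes matches the tile constraints of Triton's tl.dot, which needs a wider K tile for fp8 operands (NVIDIA's fp8 MMA fragments use k=32); treat that rationale as an inference from the diff rather than a documented contract. A sketch of the selection under the same itemsize convention (pick_block_k is a hypothetical name):

import torch

def pick_block_k(dtype: torch.dtype) -> int:
    # 16-wide K tiles suffice for 16/32-bit dtypes; one-byte (fp8)
    # operands get 32, mirroring the config change above.
    return 16 if dtype.itemsize > 1 else 32

assert pick_block_k(torch.bfloat16) == 16
assert pick_block_k(torch.float8_e4m3fn) == 32  # needs fp8-capable PyTorch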