@@ -77,8 +77,8 @@ def make_tensors(config: BatchedMMConfig):
77
77
@pytest .mark .parametrize (
78
78
"dtype" ,
79
79
[torch .float8_e4m3fn , torch .float32 , torch .float16 , torch .bfloat16 ])
80
- @pytest .mark .parametrize ("block_shape" , [None , [128 , 128 ]]) # [None])#, [128, 128]])
81
- @pytest .mark .parametrize ("per_act_token_quant" , [False , True ])# [False])# ,True])
80
+ @pytest .mark .parametrize ("block_shape" , [None , [128 , 128 ]])
81
+ @pytest .mark .parametrize ("per_act_token_quant" , [False , True ])
82
82
def test_batched_mm (num_experts : int , max_tokens_per_expert : int , K : int ,
83
83
N : int , dtype : torch .dtype ,
84
84
block_shape : Optional [list [int ]],
@@ -141,8 +141,6 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
141
141
142
142
assert A_q .dtype == B_q .dtype
143
143
144
- #B_scale.fill_(0.5)
145
-
146
144
invoke_moe_batched_triton_kernel (
147
145
A_q ,
148
146
B_q ,
@@ -190,7 +188,7 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
190
188
print (f"REF_OUTPUT { q_ref_output .shape } \n { q_ref_output } " )
191
189
print (f"TRITON { test_output .shape } \n { test_output } " )
192
190
193
- # torch.testing.assert_close(ref_output, q_ref_output, atol=atol, rtol=rtol)
191
+ torch .testing .assert_close (ref_output , q_ref_output , atol = atol , rtol = rtol )
194
192
#torch.testing.assert_close(ref_output, test_output, atol=atol, rtol=rtol)
195
193
torch .testing .assert_close (test_output , q_ref_output , atol = atol , rtol = rtol )
196
194
@@ -246,9 +244,6 @@ def test_fused_moe_batched_experts(
246
244
per_act_token_quant = per_act_token_quant ,
247
245
)
248
246
249
- # TODO remove
250
- torch .set_printoptions (profile = "full" )
251
-
252
247
with set_current_vllm_config (vllm_config ):
253
248
topk_weight , topk_ids , _ = fused_topk (a , score , topk , False )
254
249
@@ -274,9 +269,9 @@ def test_fused_moe_batched_experts(
274
269
else :
275
270
baseline_output = torch_experts (a , w1_16 , w2_16 , topk_weight , topk_ids )
276
271
277
- # triton_output = triton_moe(a, w1, w2, topk_weight, topk_ids, w1_s,
278
- # w2_s, quant_dtype, per_act_token_quant,
279
- # block_shape)
272
+ triton_output = triton_moe (a , w1 , w2 , topk_weight , topk_ids , w1_s ,
273
+ w2_s , quant_dtype , per_act_token_quant ,
274
+ block_shape )
280
275
281
276
#print(f"TORCH {baseline_output.shape}\n{baseline_output}")
282
277
#print(f"TRITON {triton_output.shape}\n{triton_output}")
@@ -292,7 +287,7 @@ def test_fused_moe_batched_experts(
292
287
# atol=2e-2,
293
288
# rtol=2e-2)
294
289
295
- # torch.testing.assert_close(triton_output,
296
- # batched_output,
297
- # atol=2e-2,
298
- # rtol=2e-2)
290
+ torch .testing .assert_close (triton_output ,
291
+ batched_output ,
292
+ atol = 2e-2 ,
293
+ rtol = 2e-2 )
0 commit comments