
Commit 3d226f5

both models work
Signed-off-by: Bill Nell <bnell@redhat.com>
1 parent: 39055b0

12 files changed: +308 −147 lines


tests/kernels/moe/test_batched_moe.py

Lines changed: 50 additions & 8 deletions
@@ -184,6 +184,14 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
     torch.testing.assert_close(ref_output, q_ref_output, atol=atol, rtol=rtol)
     torch.testing.assert_close(test_output, q_ref_output, atol=atol, rtol=rtol)

+# @pytest.mark.parametrize("m", [6, 16, 199, 200, 256])
+# @pytest.mark.parametrize("n", [2816//2])
+# @pytest.mark.parametrize("k", [2048])
+# @pytest.mark.parametrize("e", [32])
+# @pytest.mark.parametrize("topk", [6])
+# @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn])
+# @pytest.mark.parametrize("per_act_token_quant", [False])
+# @pytest.mark.parametrize("block_shape", [None])

 @pytest.mark.parametrize("m", [1, 32, 45, 64, 222])
 @pytest.mark.parametrize("n", [128, 512, 1024, 2048])
@@ -193,6 +201,7 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
 @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn, torch.bfloat16])
 @pytest.mark.parametrize("per_act_token_quant", [False, True])
 @pytest.mark.parametrize("block_shape", [None, [128, 128]])
+@pytest.mark.parametrize("input_scales", [False])
 def test_fused_moe_batched_experts(
     m: int,
     n: int,
@@ -202,6 +211,7 @@ def test_fused_moe_batched_experts(
     dtype: torch.dtype,
     per_act_token_quant: bool,
     block_shape: Optional[list[int]],
+    input_scales: bool,
 ):
     current_platform.seed_everything(7)

@@ -236,13 +246,16 @@ def test_fused_moe_batched_experts(
         per_act_token_quant=per_act_token_quant,
     )

+    if input_scales and quant_dtype is not None:
+        a1_scale = torch.tensor(1, device="cuda", dtype=torch.float32)
+        a2_scale = torch.tensor(1, device="cuda", dtype=torch.float32)
+    else:
+        a1_scale = None
+        a2_scale = None
+
     with set_current_vllm_config(vllm_config):
         topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)

-        batched_output = naive_batched_moe(a, w1, w2, topk_weight, topk_ids, w1_s,
-                                           w2_s, quant_dtype, per_act_token_quant,
-                                           block_shape)
-
         baseline_output = torch_experts(
             a,
             w1,
@@ -251,13 +264,42 @@ def test_fused_moe_batched_experts(
             topk_ids,
             w1_scale=w1_s,
             w2_scale=w2_s,
+            a1_scale=a1_scale,
+            a2_scale=a2_scale,
+            quant_dtype=quant_dtype,
+            per_act_token_quant=per_act_token_quant,
+            block_shape=block_shape,
+        )
+
+        batched_output = naive_batched_moe(
+            a,
+            w1,
+            w2,
+            topk_weight,
+            topk_ids,
+            w1_scale=w1_s,
+            w2_scale=w2_s,
+            a1_scale=a1_scale,
+            a2_scale=a2_scale,
             quant_dtype=quant_dtype,
             per_act_token_quant=per_act_token_quant,
-            block_shape=block_shape)
+            block_shape=block_shape,
+        )

-        triton_output = triton_moe(a, w1, w2, topk_weight, topk_ids, w1_s,
-                                   w2_s, quant_dtype, per_act_token_quant,
-                                   block_shape)
+        triton_output = batched_moe(
+            a,
+            w1,
+            w2,
+            topk_weight,
+            topk_ids,
+            w1_scale=w1_s,
+            w2_scale=w2_s,
+            a1_scale=a1_scale,
+            a2_scale=a2_scale,
+            quant_dtype=quant_dtype,
+            per_act_token_quant=per_act_token_quant,
+            block_shape=block_shape,
+        )

         torch.testing.assert_close(batched_output,
                                    baseline_output,
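
Note: the new input_scales parameter (only False in the sweep for now) makes the test hand the same scalar per-tensor activation scales (a1_scale, a2_scale) to the torch reference, the naive batched path, and the Triton batched path, so all three quantize activations identically. For context only, the snippet below is a minimal sketch of what a scalar per-tensor fp8 activation scale represents; the helper name and scale convention are illustrative assumptions, not the vLLM code paths under test.

from typing import Optional

import torch

def quant_fp8_per_tensor(x: torch.Tensor,
                         scale: Optional[torch.Tensor] = None):
    # One scale for the whole tensor; this is what a scalar a1_scale/a2_scale stands for.
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    if scale is None:
        scale = x.abs().amax().to(torch.float32) / fp8_max
    x_q = (x.to(torch.float32) / scale).clamp(-fp8_max, fp8_max)
    return x_q.to(torch.float8_e4m3fn), scale

x = torch.randn(32, 128, dtype=torch.bfloat16)
x_q, s = quant_fp8_per_tensor(x)
x_dq = x_q.to(torch.float32) * s  # dequantize for a reference comparison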

tests/kernels/moe/test_pplx_moe.py

Lines changed: 113 additions & 42 deletions
@@ -40,26 +40,24 @@
     reason="Requires PPLX kernels",
 )

-PPLX_PREPARE_COMBOS = [
+PPLX_COMBOS = [
     # TODO: figure out why this fails
     #(1, 128, 128),
+
     (2, 128, 512),
     (3, 1024, 2048),
     (4, 128, 128),
     (32, 1024, 512),
     (45, 512, 2048),
     (64, 1024, 512),
     (222, 2048, 1024),
-]
+    (256, 1408, 2048),

-PPLX_MOE_COMBOS = [
-    # (1, 128, 128),
-    (2, 128, 512),
-    (3, 1024, 2048),
-    (32, 128, 1024),
-    (45, 512, 2048),
-    (64, 1024, 1024),
-    (222, 1024, 2048),
+    #(6, 1408, 2048),
+    #(16, 1408, 2048),
+    #(199, 1408, 2048),
+    #(200, 1408, 2048),
+    #(256, 1408, 2048),
 ]

 NUM_EXPERTS = [8, 64]
@@ -280,10 +278,17 @@ def pplx_prepare_finalize(
         device=device,
     )

+    if quant_dtype is not None and not per_act_token_quant and block_shape is None:
+        a1_scale = torch.tensor(1.0, device="cuda", dtype=torch.float32)
+        a2_scale = torch.tensor(1.0, device="cuda", dtype=torch.float32)
+    else:
+        a1_scale = None
+        a2_scale = None
+
     b_a, b_a_scale, expert_num_tokens, _, _ = prepare_finalize.prepare(
         a_chunk,
-        None,
-        None,
+        a1_scale,
+        a2_scale,
         chunk_topk_weight,
         chunk_topk_ids,
         num_experts,
@@ -364,7 +369,7 @@ def _pplx_prepare_finalize(

 # TODO (bnell): this test point does not work for M==1 due to how the test
 # is written, not due to limitations of the pplx kernels.
-@pytest.mark.parametrize("mnk", PPLX_PREPARE_COMBOS)
+@pytest.mark.parametrize("mnk", PPLX_COMBOS)
 @pytest.mark.parametrize("e", NUM_EXPERTS)
 @pytest.mark.parametrize("topk", TOP_KS)
 @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn, torch.bfloat16])
@@ -423,7 +428,9 @@ def pplx_moe(
     topk_ids: torch.Tensor,
     w1_scale: Optional[torch.Tensor] = None,
     w2_scale: Optional[torch.Tensor] = None,
-    qtype: Optional[torch.dtype] = None,
+    a1_scale: Optional[torch.Tensor] = None,
+    a2_scale: Optional[torch.Tensor] = None,
+    quant_dtype: Optional[torch.dtype] = None,
     per_act_token_quant=False,
     block_shape: Optional[list[int]] = None,
     use_compile: bool = False,
@@ -442,7 +449,7 @@ def pplx_moe(
         max_num_tokens,
         hidden_dim,
         a.dtype,
-        qtype,
+        quant_dtype,
         per_act_token_quant=per_act_token_quant,
         block_shape=block_shape,
     )
@@ -478,7 +485,7 @@ def pplx_moe(
     experts = BatchedTritonExperts(max_num_tokens=max_num_tokens,
                                    world_size=world_size,
                                    dp_size=dp_size,
-                                   use_fp8_w8a8=qtype == torch.float8_e4m3fn,
+                                   use_fp8_w8a8=quant_dtype == torch.float8_e4m3fn,
                                    block_shape=block_shape)

     fused_experts = FusedMoEModularKernel(
@@ -496,12 +503,8 @@
     w2_chunk = chunk_by_rank(w2, rank, world_size).to(device)

     if w1_scale is not None:
-        if not per_act_token_quant:
-            w1_scale_chunk = w1_scale
-            w2_scale_chunk = w2_scale
-        else:
-            w1_scale_chunk = chunk_by_rank(w1_scale, rank, world_size).to(device)
-            w2_scale_chunk = chunk_by_rank(w2_scale, rank, world_size).to(device)
+        w1_scale_chunk = chunk_by_rank(w1_scale, rank, world_size).to(device)
+        w2_scale_chunk = chunk_by_rank(w2_scale, rank, world_size).to(device)
     else:
         w1_scale_chunk = None
         w2_scale_chunk = None
@@ -526,6 +529,8 @@
         chunk_topk_ids,
         w1_scale=w1_scale_chunk,
         w2_scale=w2_scale_chunk,
+        a1_scale=a1_scale,
+        a2_scale=a2_scale,
         global_num_experts=num_experts)

     if use_cudagraphs:
@@ -540,6 +545,8 @@
             chunk_topk_ids,
             w1_scale=w1_scale_chunk,
             w2_scale=w2_scale_chunk,
+            a1_scale=a1_scale,
+            a2_scale=a2_scale,
             global_num_experts=num_experts)

         torch.cuda.synchronize()
@@ -562,7 +569,7 @@ def _pplx_moe(
     topk: int,
     w1_s: Optional[torch.Tensor] = None,
     w2_s: Optional[torch.Tensor] = None,
-    qtype: Optional[torch.dtype] = None,
+    quant_dtype: Optional[torch.dtype] = None,
     per_act_token_quant: bool = False,
     block_shape: Optional[list[int]] = None,
     use_internode: bool = False,
@@ -584,48 +591,112 @@ def _pplx_moe(
     moe_config = get_default_config(m, e, n, k, topk, a.dtype, False)

     device = torch.device("cuda", pgi.rank)
+    rank = pgi.rank
+    world_size = pgi.world_size
     a = a.to(device)
     w1 = w1.to(device)
     w2 = w2.to(device)
     w1_s = w1_s.to(device) if w1_s is not None else None
     w2_s = w2_s.to(device) if w2_s is not None else None

+    if quant_dtype is not None and not per_act_token_quant and block_shape is None:
+        a1_scale = torch.tensor(1.0, device="cuda", dtype=torch.float32)
+        a2_scale = torch.tensor(1.0, device="cuda", dtype=torch.float32)
+    else:
+        a1_scale = None
+        a2_scale = None
+
     with set_current_vllm_config(vllm_config), override_config(moe_config):
         topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)
-        torch_output = torch_experts(a,
-                                     w1,
-                                     w2,
-                                     topk_weight,
-                                     topk_ids,
-                                     w1_scale=w1_s,
-                                     w2_scale=w2_s,
-                                     quant_dtype=qtype,
-                                     per_act_token_quant=per_act_token_quant,
-                                     block_shape=block_shape)
-
-        pplx_output = pplx_moe(group_name, pgi.rank, pgi.world_size, dp_size,
-                               a, w1, w2, topk_weight, topk_ids, w1_s, w2_s,
-                               qtype, per_act_token_quant, block_shape)
+
+        if False:
+            a_chunk = chunk_by_rank(a, rank, world_size).to(device)
+            topk_weight_chunk = chunk_by_rank(topk_weight, rank, world_size).to(device)
+            topk_ids_chunk = chunk_by_rank(topk_ids, rank, world_size).to(device)
+            w1_chunk = chunk_by_rank(w1, rank, world_size).to(device)
+            w2_chunk = chunk_by_rank(w2, rank, world_size).to(device)
+
+            if w1_s is not None:
+                w1_s_chunk = chunk_by_rank(w1_s, rank, world_size).to(device)
+                w2_s_chunk = chunk_by_rank(w2_s, rank, world_size).to(device)
+            else:
+                w1_s_chunk = None
+                w2_s_chunk = None
+        else:
+            a_chunk = a
+            topk_weight_chunk = topk_weight
+            topk_ids_chunk = topk_ids
+            w1_chunk = w1
+            w2_chunk = w2
+            w1_s_chunk = w1_s
+            w2_s_chunk = w2_s
+
+        torch_output = torch_experts(
+            a_chunk,
+            w1_chunk,
+            w2_chunk,
+            topk_weight_chunk,
+            topk_ids_chunk,
+            w1_scale=w1_s_chunk,
+            w2_scale=w2_s_chunk,
+            a1_scale=a1_scale,
+            a2_scale=a2_scale,
+            quant_dtype=quant_dtype,
+            per_act_token_quant=per_act_token_quant,
+            block_shape=block_shape,
+        )
+
+        batched_output = naive_batched_moe(
+            a_chunk,
+            w1_chunk,
+            w2_chunk,
+            topk_weight_chunk,
+            topk_ids_chunk,
+            w1_scale=w1_s_chunk,
+            w2_scale=w2_s_chunk,
+            a1_scale=a1_scale,
+            a2_scale=a2_scale,
+            quant_dtype=quant_dtype,
+            per_act_token_quant=per_act_token_quant,
+            block_shape=block_shape,
+        )
+
+        pplx_output = pplx_moe(
+            group_name,
+            rank,
+            world_size,
+            dp_size,
+            a,
+            w1,
+            w2,
+            topk_weight,
+            topk_ids,
+            w1_scale=w1_s,
+            w2_scale=w2_s,
+            a1_scale=a1_scale,
+            a2_scale=a2_scale,
+            quant_dtype=quant_dtype,
+            per_act_token_quant=per_act_token_quant,
+            block_shape=block_shape)

         # all reduce on pplx?
         #torch.distributed.all_reduce(pplx_output)

-        batched_output = naive_batched_moe(a, w1, w2, topk_weight,
-            topk_ids, w1_s, w2_s, qtype, per_act_token_quant, block_shape)
-
         chunked_torch_output = chunk_by_rank(torch_output, pgi.rank,
                                              pgi.world_size).to(pplx_output.device)

         torch.testing.assert_close(pplx_output, chunked_torch_output, atol=3e-2, rtol=3e-2)
-        #torch.testing.assert_close(batched_output, torch_output, atol=3e-2, rtol=3e-2)
+        torch.testing.assert_close(batched_output, torch_output, atol=3e-2, rtol=3e-2)

         if use_internode:
             nvshmem_finalize()


-@pytest.mark.parametrize("mnk", PPLX_MOE_COMBOS)
+@pytest.mark.parametrize("mnk", PPLX_COMBOS)
 @pytest.mark.parametrize("e", NUM_EXPERTS)
 @pytest.mark.parametrize("topk", TOP_KS)
+#@pytest.mark.parametrize("e", [32])
+#@pytest.mark.parametrize("topk", [6])
 @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn, torch.bfloat16])
 @pytest.mark.parametrize("world_dp_size", [[2, 1]])
 @pytest.mark.parametrize("per_act_token_quant", [False, True])