Skip to content

Commit 39055b0

Browse files
committed
qwen works, rh-ds broken now, pplx_moe tests not all working
Signed-off-by: Bill Nell <bnell@redhat.com>
1 parent f92734e commit 39055b0

File tree

12 files changed

+158
-344
lines changed

12 files changed

+158
-344
lines changed

tests/kernels/moe/test_batched_moe.py

Lines changed: 18 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -100,8 +100,6 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
100100
act_dtype = dtype
101101
quant_dtype = None
102102

103-
#print(f"TYPES {dtype}, {act_dtype}, {quant_dtype}")
104-
105103
num_expert_tokens = torch.randint(low=0,
106104
high=max_tokens_per_expert,
107105
size=(num_experts, ),
@@ -183,19 +181,13 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
183181
torch.float32: (1e-2, 1e-2),
184182
}[test_output.dtype]
185183

186-
if False:
187-
torch.set_printoptions(profile="full")
188-
print(f"REF_OUTPUT {q_ref_output.shape}\n{q_ref_output}")
189-
print(f"TRITON {test_output.shape}\n{test_output}")
190-
191184
torch.testing.assert_close(ref_output, q_ref_output, atol=atol, rtol=rtol)
192-
#torch.testing.assert_close(ref_output, test_output, atol=atol, rtol=rtol)
193185
torch.testing.assert_close(test_output, q_ref_output, atol=atol, rtol=rtol)
194186

195187

196188
@pytest.mark.parametrize("m", [1, 32, 45, 64, 222])
197-
@pytest.mark.parametrize("n", [128, 512, 1024])#, 2048])
198-
@pytest.mark.parametrize("k", [128, 512, 1024])#, 2048])
189+
@pytest.mark.parametrize("n", [128, 512, 1024, 2048])
190+
@pytest.mark.parametrize("k", [128, 512, 1024, 2048])
199191
@pytest.mark.parametrize("e", NUM_EXPERTS)
200192
@pytest.mark.parametrize("topk", TOP_KS)
201193
@pytest.mark.parametrize("dtype", [torch.float8_e4m3fn, torch.bfloat16])
@@ -221,7 +213,7 @@ def test_fused_moe_batched_experts(
221213
if not use_fp8_w8a8 and per_act_token_quant and block_shape is not None:
222214
pytest.skip("Skip quantization test for non-quantized type")
223215

224-
if (per_act_token_quant and block_shape is not None) or topk > e:
216+
if per_act_token_quant and block_shape is not None:
225217
pytest.skip("Skip illegal quantization test")
226218

227219
a = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) / 10
@@ -247,46 +239,31 @@ def test_fused_moe_batched_experts(
247239
with set_current_vllm_config(vllm_config):
248240
topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)
249241

250-
if True:
251-
batched_output = naive_batched_moe(a, w1, w2, topk_weight, topk_ids, w1_s,
252-
w2_s, quant_dtype, per_act_token_quant,
253-
block_shape)
254-
else:
255-
batched_output = naive_batched_moe(a, w1_16, w2_16, topk_weight, topk_ids)
256-
257-
if True:
258-
baseline_output = torch_experts(
259-
a,
260-
w1,
261-
w2,
262-
topk_weight,
263-
topk_ids,
264-
w1_scale=w1_s,
265-
w2_scale=w2_s,
266-
quant_dtype=quant_dtype,
267-
per_act_token_quant=per_act_token_quant,
268-
block_shape=block_shape)
269-
else:
270-
baseline_output = torch_experts(a, w1_16, w2_16, topk_weight, topk_ids)
242+
batched_output = naive_batched_moe(a, w1, w2, topk_weight, topk_ids, w1_s,
243+
w2_s, quant_dtype, per_act_token_quant,
244+
block_shape)
245+
246+
baseline_output = torch_experts(
247+
a,
248+
w1,
249+
w2,
250+
topk_weight,
251+
topk_ids,
252+
w1_scale=w1_s,
253+
w2_scale=w2_s,
254+
quant_dtype=quant_dtype,
255+
per_act_token_quant=per_act_token_quant,
256+
block_shape=block_shape)
271257

272258
triton_output = triton_moe(a, w1, w2, topk_weight, topk_ids, w1_s,
273259
w2_s, quant_dtype, per_act_token_quant,
274260
block_shape)
275261

276-
#print(f"TORCH {baseline_output.shape}\n{baseline_output}")
277-
#print(f"TRITON {triton_output.shape}\n{triton_output}")
278-
#print(f"BATCHED {batched_output.shape}\n{batched_output}")
279-
280262
torch.testing.assert_close(batched_output,
281263
baseline_output,
282264
atol=3e-2,
283265
rtol=2e-2)
284266

285-
# torch.testing.assert_close(triton_output,
286-
# baseline_output,
287-
# atol=2e-2,
288-
# rtol=2e-2)
289-
290267
torch.testing.assert_close(triton_output,
291268
batched_output,
292269
atol=2e-2,

tests/kernels/moe/test_pplx_moe.py

Lines changed: 42 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121
from tests.kernels.moe.utils import make_test_weights, naive_batched_moe
2222
from tests.kernels.utils import torch_experts
23-
from tests.kernels.quant_utils import dequant
23+
from tests.kernels.quant_utils import batched_dequant
2424
from vllm.config import VllmConfig, set_current_vllm_config
2525
from vllm.model_executor.layers.fused_moe import fused_topk, override_config
2626
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
@@ -41,16 +41,19 @@
4141
)
4242

4343
PPLX_PREPARE_COMBOS = [
44-
# (1, 128, 128),
44+
# TODO: figure out why this fails
45+
#(1, 128, 128),
46+
(2, 128, 512),
47+
(3, 1024, 2048),
4548
(4, 128, 128),
4649
(32, 1024, 512),
47-
# (45, 512, 2048),
50+
(45, 512, 2048),
4851
(64, 1024, 512),
4952
(222, 2048, 1024),
5053
]
5154

5255
PPLX_MOE_COMBOS = [
53-
(1, 128, 128),
56+
# (1, 128, 128),
5457
(2, 128, 512),
5558
(3, 1024, 2048),
5659
(32, 128, 1024),
@@ -202,7 +205,7 @@ def chunk_by_rank(t: torch.Tensor, r: int, w: int) -> torch.Tensor:
202205

203206

204207
def dummy_work(a: torch.Tensor) -> torch.Tensor:
205-
return a # * 1.5
208+
return a * 1.1
206209

207210

208211
def pplx_prepare_finalize(
@@ -270,6 +273,13 @@ def pplx_prepare_finalize(
270273
chunk_topk_weight = chunk_by_rank(topk_weight, rank, world_size).to(device)
271274
chunk_topk_ids = chunk_by_rank(topk_ids, rank, world_size).to(device)
272275

276+
out = torch.full(
277+
(max_num_tokens, hidden_dim),
278+
torch.nan,
279+
dtype=a.dtype,
280+
device=device,
281+
)
282+
273283
b_a, b_a_scale, expert_num_tokens, _, _ = prepare_finalize.prepare(
274284
a_chunk,
275285
None,
@@ -287,16 +297,10 @@ def pplx_prepare_finalize(
287297
),
288298
)
289299

290-
# Do some fake work
291-
#print(f"INTER {b_a.shape} {b_a_scale.shape if b_a_scale is not None else None}")
292-
b_a = dummy_work(dequant(b_a, b_a_scale, block_shape, per_act_token_quant, a.dtype))
300+
#print(f"B_A_SCALE = {b_a.shape}, {b_a_scale.shape if b_a_scale is not None else None}, {per_act_token_quant} {block_shape}, {a_chunk.shape}")
301+
# TOOD: shouldn't need batched_dequant
293302

294-
out = torch.full(
295-
(max_num_tokens, hidden_dim),
296-
torch.nan,
297-
dtype=a.dtype,
298-
device=device,
299-
)
303+
b_a = dummy_work(batched_dequant(b_a, b_a_scale, block_shape, per_act_token_quant, a.dtype))
300304

301305
prepare_finalize.finalize(
302306
out,
@@ -338,49 +342,34 @@ def _pplx_prepare_finalize(
338342
cpu_group = torch.distributed.new_group(group_ranks, backend="gloo")
339343
group_name = cpu_group.group_name
340344

341-
#device = pgi.device
342-
343345
topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)
344346
m, k = a.shape
345347

346-
a_rep = torch.repeat_interleave(dummy_work(a), topk, dim=0) #.to(device)
348+
a_rep = torch.repeat_interleave(dummy_work(a), topk, dim=0)
347349

348-
if True:
349-
torch_output = (a_rep.view(m, topk, k) *
350-
topk_weight.view(m, topk, 1).to(a_rep.dtype)).sum(dim=1)
351-
else:
352-
import vllm._custom_ops as ops
353-
a_rep = a_rep.view(m, topk, k)
354-
a_rep.mul_(topk_weight.view(m, topk, 1).to(a_rep.dtype))
355-
torch_output = torch.empty_like(a)
356-
ops.moe_sum(a_rep, torch_output)
350+
torch_output = (a_rep.view(m, topk, k) *
351+
topk_weight.view(m, topk, 1).to(a_rep.dtype)).sum(dim=1)
357352

358353
pplx_output = pplx_prepare_finalize(pgi, dp_size, a, topk_weight, topk_ids,
359354
num_experts, quant_dtype, block_shape,
360355
per_act_token_quant, group_name)
361356

362-
torch_output = chunk_by_rank(torch_output, pgi.rank,
363-
pgi.world_size).to(pplx_output.device)
364-
365-
#torch.set_printoptions(profile="full")
366-
#print(f"PPLX {pplx_output.shape}\n{pplx_output.shape}")
367-
#print(f"TORCH {torch_output.shape}\n{torch_output.shape}")
357+
torch_output = chunk_by_rank(torch_output, pgi.rank, pgi.world_size).to(pgi.device)
368358

369359
torch.testing.assert_close(pplx_output, torch_output, atol=3e-2, rtol=3e-2)
370360

371361
if use_internode:
372362
nvshmem_finalize()
373363

374364

375-
# TODO (bnell): this test point does not work for odd M due to how the test is
376-
# written, not due to limitations of the pplx kernels. The pplx_moe
377-
# test below is able to deal with odd M.
365+
# TODO (bnell): this test point does not work for M==1 due to how the test
366+
# is written, not due to limitations of the pplx kernels.
378367
@pytest.mark.parametrize("mnk", PPLX_PREPARE_COMBOS)
379368
@pytest.mark.parametrize("e", NUM_EXPERTS)
380369
@pytest.mark.parametrize("topk", TOP_KS)
381370
@pytest.mark.parametrize("dtype", [torch.float8_e4m3fn, torch.bfloat16])
382371
@pytest.mark.parametrize("world_dp_size", [[2, 1]])
383-
@pytest.mark.parametrize("per_act_token_quant", [False])
372+
@pytest.mark.parametrize("per_act_token_quant", [False, True])
384373
@pytest.mark.parametrize("block_shape", [None, [128, 128]])
385374
@pytest.mark.parametrize("use_internode", [False])
386375
@requires_pplx
@@ -414,8 +403,6 @@ def test_pplx_prepare_finalize(
414403
world_size, dp_size = world_dp_size
415404
device = "cuda"
416405

417-
#print(f"MNK = {mnk}")
418-
419406
a = torch.randn((m, k), device=device, dtype=act_dtype) / 10
420407
score = torch.randn((m, e), device=device, dtype=act_dtype)
421408

@@ -508,10 +495,13 @@ def pplx_moe(
508495
w1_chunk = chunk_by_rank(w1, rank, world_size).to(device)
509496
w2_chunk = chunk_by_rank(w2, rank, world_size).to(device)
510497

511-
# TODO scale chunk function
512498
if w1_scale is not None:
513-
w1_scale_chunk = chunk_by_rank(w1_scale, rank, world_size).to(device)
514-
w2_scale_chunk = chunk_by_rank(w2_scale, rank, world_size).to(device)
499+
if not per_act_token_quant:
500+
w1_scale_chunk = w1_scale
501+
w2_scale_chunk = w2_scale
502+
else:
503+
w1_scale_chunk = chunk_by_rank(w1_scale, rank, world_size).to(device)
504+
w2_scale_chunk = chunk_by_rank(w2_scale, rank, world_size).to(device)
515505
else:
516506
w1_scale_chunk = None
517507
w2_scale_chunk = None
@@ -562,48 +552,6 @@ def pplx_moe(
562552
return out
563553

564554

565-
def _batched_moe(pgi, dp_size, a, w1, w2, topk_weight, topk_ids):
566-
assert torch.cuda.current_device() == pgi.local_rank
567-
568-
num_experts = w1.shape[0]
569-
device = pgi.device
570-
rank = pgi.rank
571-
world_size = pgi.world_size
572-
max_num_tokens = rank_chunk(a.shape[0], 0, world_size)
573-
574-
prepare_finalize = BatchedPrepareAndFinalize(
575-
max_num_tokens=max_num_tokens,
576-
world_size=world_size,
577-
dp_size=dp_size,
578-
rank=rank,
579-
)
580-
581-
experts = NaiveBatchedExperts(max_num_tokens=a.shape[0],
582-
world_size=1,
583-
dp_size=1)
584-
585-
fused_experts = FusedMoEModularKernel(
586-
prepare_finalize,
587-
experts,
588-
)
589-
590-
# Note: workers with the same dp_rank must use the exact same inputs.
591-
a_chunk = chunk_by_rank(a, rank, world_size).to(device)
592-
chunk_topk_weight = chunk_by_rank(topk_weight, rank, world_size).to(device)
593-
chunk_topk_ids = chunk_by_rank(topk_ids, rank, world_size).to(device)
594-
595-
out = fused_experts(
596-
a_chunk,
597-
# Chunking weights like this only works for batched format
598-
chunk_by_rank(w1, rank, world_size).to(device),
599-
chunk_by_rank(w2, rank, world_size).to(device),
600-
chunk_topk_weight,
601-
chunk_topk_ids,
602-
global_num_experts=num_experts)
603-
604-
return out
605-
606-
607555
def _pplx_moe(
608556
pgi: ProcessGroupInfo,
609557
dp_size: int,
@@ -654,18 +602,22 @@ def _pplx_moe(
654602
quant_dtype=qtype,
655603
per_act_token_quant=per_act_token_quant,
656604
block_shape=block_shape)
605+
657606
pplx_output = pplx_moe(group_name, pgi.rank, pgi.world_size, dp_size,
658607
a, w1, w2, topk_weight, topk_ids, w1_s, w2_s,
659608
qtype, per_act_token_quant, block_shape)
660-
# TODO (bnell): fix + re-enable
661-
#batched_output = _batched_moe(pgi, dp_size, a, w1, w2, topk_weight,
662-
# topk_ids)
663609

664-
torch_output = chunk_by_rank(torch_output, pgi.rank,
665-
pgi.world_size).to(pplx_output.device)
610+
# all reduce on pplx?
611+
#torch.distributed.all_reduce(pplx_output)
612+
613+
batched_output = naive_batched_moe(a, w1, w2, topk_weight,
614+
topk_ids, w1_s, w2_s, qtype, per_act_token_quant, block_shape)
615+
616+
chunked_torch_output = chunk_by_rank(torch_output, pgi.rank,
617+
pgi.world_size).to(pplx_output.device)
666618

667-
torch.testing.assert_close(pplx_output, torch_output, atol=2e-2, rtol=0)
668-
#torch.testing.assert_close(batched_output, torch_output, atol=2e-2, rtol=0)
619+
torch.testing.assert_close(pplx_output, chunked_torch_output, atol=3e-2, rtol=3e-2)
620+
#torch.testing.assert_close(batched_output, torch_output, atol=3e-2, rtol=3e-2)
669621

670622
if use_internode:
671623
nvshmem_finalize()

tests/kernels/moe/utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,7 @@ def make_quantized_test_activations(
172172
return a, a_q, a_scale
173173

174174

175+
# TODO: split this into two calls to a single function
175176
def make_test_weights(
176177
e: int,
177178
n: int,

tests/kernels/quant_utils.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -237,8 +237,8 @@ def per_block_cast_to_fp8(
237237

238238

239239
def dequant(
240-
t: torch.Tensor, scale:
241-
Optional[torch.Tensor],
240+
t: torch.Tensor,
241+
scale: Optional[torch.Tensor],
242242
block_shape: Optional[list[int]],
243243
per_act_token_quant: bool,
244244
out_dtype: Optional[torch.dtype] = torch.float32,
@@ -253,6 +253,23 @@ def dequant(
253253
return t.to(out_dtype)
254254

255255

256+
def batched_dequant(
257+
t: torch.Tensor, scale:
258+
Optional[torch.Tensor],
259+
block_shape: Optional[list[int]],
260+
per_act_token_quant: bool,
261+
out_dtype: Optional[torch.dtype] = torch.float32,
262+
) -> torch.Tensor:
263+
if scale is not None:
264+
assert t.shape[0] == scale.shape[0]
265+
out = torch.empty_like(t, dtype=out_dtype)
266+
for e in range(t.shape[0]):
267+
out[e] = dequant(t[e], scale[e], block_shape, per_act_token_quant, out_dtype)
268+
return out
269+
270+
return t.to(out_dtype)
271+
272+
256273
def native_batched_masked_quant_matmul(
257274
A: torch.Tensor,
258275
B: torch.Tensor,

vllm/model_executor/layers/fused_moe/config.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,6 @@ def make(
129129
use_int8_w8a8=use_int8_w8a8,
130130
use_int8_w8a16=use_int8_w8a16,
131131
use_int4_w4a16=use_int4_w4a16)
132-
133132
return FusedMoEQuantConfig(
134133
quant_dtype,
135134
per_act_token_quant,
@@ -300,6 +299,8 @@ def __post_init__(self):
300299
logger.debug("Using FusedMoEConfig::max_num_tokens=%d",
301300
self.max_num_tokens)
302301

302+
assert self.max_num_tokens > 0
303+
303304
@property
304305
def quant_dtype(self) -> Optional[torch.dtype]:
305306
if self.quant_config is not None:

0 commit comments

Comments
 (0)