Commit fdadb6f

Varun Sundar Rabindranath authored
[Bugfix] Fused MoE Modular Kernel chunking loop (#20392)
Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
1 parent 41060c6 commit fdadb6f

File tree

4 files changed: +404 −107 lines changed

Lines changed: 140 additions & 0 deletions
@@ -0,0 +1,140 @@

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Tests compute_expert_num_tokens kernels
"""

import dataclasses
from typing import Optional

import pytest
import torch

from vllm.model_executor.layers.fused_moe.utils import count_expert_num_tokens


@dataclasses.dataclass
class TestTensors:

    topk_ids: torch.Tensor
    expert_map: Optional[torch.Tensor] = None

    def to_device(self, device: str):
        self.topk_ids = self.topk_ids.to(device=device)
        if self.expert_map is not None:
            self.expert_map = self.expert_map.to(device=device)

    @staticmethod
    def make(num_tokens: int, num_topk: int, num_experts: int, device: str,
             topk_ids_dtype: torch.dtype) -> "TestTensors":

        # make topk ids
        topk_ids = torch.empty((num_tokens, num_topk),
                               device=device,
                               dtype=torch.int64)
        for x in range(num_tokens):
            topk_ids[x] = torch.randperm(num_experts)[:num_topk]
        topk_ids = topk_ids.to(dtype=torch.int64)
        return TestTensors(topk_ids=topk_ids)

    def with_ep_rank(self, ep_rank: int, num_global_experts: int,
                     num_local_experts: int, device: str):
        # make an expert map
        expert_map = torch.empty((num_global_experts),
                                 device=device,
                                 dtype=torch.int32)
        expert_map.fill_(-1)
        s = ep_rank * num_local_experts
        e = s + num_local_experts
        expert_map[s:e] = torch.tensor(list(range(num_local_experts)),
                                       device=device)

        return TestTensors(topk_ids=self.topk_ids.clone(),
                           expert_map=expert_map)


def ref_impl(tt: TestTensors, expert_num_tokens: torch.Tensor):
    # do the reference in cpu
    tt.to_device("cpu")
    expert_ids, counts = tt.topk_ids.unique(return_counts=True)

    for eid, count in zip(expert_ids, counts):
        if eid != -1 and tt.expert_map is not None:
            eid = tt.expert_map[eid]

        if eid == -1:
            continue

        expert_num_tokens[eid] += count


def do_test_compute_expert_num_tokens(num_tokens: int, num_topk: int,
                                      num_experts: int, ep_size: int,
                                      topk_ids_dtype: torch.dtype):

    assert num_topk <= num_experts

    tt = TestTensors.make(num_tokens,
                          num_topk,
                          num_experts,
                          topk_ids_dtype=topk_ids_dtype,
                          device="cpu")

    num_global_experts = num_experts
    assert num_global_experts % ep_size == 0
    num_local_experts = num_global_experts // ep_size
    for ep_rank in range(ep_size):
        tt_rank = tt.with_ep_rank(ep_rank, num_global_experts,
                                  num_local_experts, "cpu")

        ref_expert_num_tokens = torch.zeros((num_local_experts),
                                            device="cpu",
                                            dtype=torch.int32)
        ref_impl(tt_rank, ref_expert_num_tokens)
        ref_expert_num_tokens = ref_expert_num_tokens.to("cuda")

        tt_rank.to_device("cuda")
        # Test with expert_map
        triton_expert_num_tokens_w_emap = count_expert_num_tokens(
            tt_rank.topk_ids, num_local_experts, tt_rank.expert_map)

        # Test without expert map
        topk_ids = tt_rank.expert_map[tt_rank.topk_ids].to(topk_ids_dtype)
        triton_expert_num_tokens_wo_emap = count_expert_num_tokens(
            topk_ids, num_local_experts, expert_map=None)

        torch.testing.assert_close(ref_expert_num_tokens,
                                   triton_expert_num_tokens_w_emap,
                                   atol=0,
                                   rtol=0)
        torch.testing.assert_close(ref_expert_num_tokens,
                                   triton_expert_num_tokens_wo_emap,
                                   atol=0,
                                   rtol=0)


@pytest.mark.parametrize(
    "num_tokens", [1, 4, 8, 11, 19, 128, 127, 405, 1024, 3333, 6666, 7317])
@pytest.mark.parametrize("num_topk", [2, 6, 8])
@pytest.mark.parametrize("num_experts", [64])
@pytest.mark.parametrize("ep_size", [1, 2, 4])
@pytest.mark.parametrize("topk_ids_dtype", [torch.int64])
def test_compute_expert_num_tokens(num_tokens: int, num_topk: int,
                                   num_experts: int, ep_size: int,
                                   topk_ids_dtype: torch.dtype):
    do_test_compute_expert_num_tokens(num_tokens, num_topk, num_experts,
                                      ep_size, topk_ids_dtype)


@pytest.mark.parametrize("numel", list(range(1, 8192, 11)))
@pytest.mark.parametrize("num_experts", [32])
@pytest.mark.parametrize("ep_size", [2])
@pytest.mark.parametrize("topk_ids_dtype", [torch.int64])
def test_compute_expert_num_tokens_from_numel(numel: int, num_experts: int,
                                              ep_size: int,
                                              topk_ids_dtype: torch.dtype):
    do_test_compute_expert_num_tokens(num_tokens=numel,
                                      num_topk=1,
                                      num_experts=num_experts,
                                      ep_size=ep_size,
                                      topk_ids_dtype=topk_ids_dtype)
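
The test above checks count_expert_num_tokens against a CPU reference, both with and without an expert_map. For readers unfamiliar with expert-parallel routing, the sketch below is not part of the commit (all names, shapes, and values are illustrative) and only demonstrates the counting semantics being verified: an expert_map translates global expert ids into local ids on one EP rank, uses -1 for experts owned by other ranks, and tokens routed to non-local experts are skipped.

# Illustrative sketch (not from the commit): count tokens per local expert
# on one EP rank, given per-token top-k global expert ids and an expert_map.
import torch

num_global_experts, num_local_experts, ep_rank = 8, 4, 1

# expert_map[g] = local id of global expert g on this rank, or -1 if remote.
expert_map = torch.full((num_global_experts,), -1, dtype=torch.int32)
expert_map[ep_rank * num_local_experts:(ep_rank + 1) * num_local_experts] = \
    torch.arange(num_local_experts, dtype=torch.int32)

# topk_ids: 3 tokens, top-2 experts each (global ids).
topk_ids = torch.tensor([[0, 5], [4, 6], [5, 7]], dtype=torch.int64)

# Map global ids to local ids, drop non-local (-1) entries, then histogram.
local_ids = expert_map[topk_ids]
counts = torch.bincount(local_ids[local_ids >= 0].to(torch.int64),
                        minlength=num_local_experts)
print(counts)  # tensor([1, 2, 1, 1]): local experts 0..3 == global 4..7

Running this on plain CPU PyTorch reproduces the per-local-expert histogram that both ref_impl and the Triton kernel under test are expected to agree on.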

vllm/model_executor/layers/fused_moe/deep_gemm_moe.py

Lines changed: 2 additions & 2 deletions
@@ -98,7 +98,7 @@ def workspace_shapes(
         M_sum = round_up(M_sum, block_m)
         workspace1 = (M_sum, max(N * 2, K))
         workspace2 = (M_sum, max(N, K))
-        output = (M * topk, K)
+        output = (M, topk, K)
         return (workspace1, workspace2, output, a.dtype)

     def apply(
@@ -172,7 +172,7 @@ def apply(
         dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
             (a2q, a2q_scale), (w2, w2_scale), mm2_out, expert_ids)

-        torch.index_select(mm2_out, 0, inv_perm, out=output)
+        torch.index_select(mm2_out, 0, inv_perm, out=output.view((-1, K)))


 def deep_gemm_moe_fp8(
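
This two-line change declares the output buffer as (M, topk, K) instead of (M * topk, K) and compensates by writing through output.view((-1, K)). The sketch below is not from the commit (the shapes and variable names are arbitrary); it only illustrates why the write stays equivalent: view does not copy, so torch.index_select scatters rows into the same contiguous memory either way.

# Illustrative sketch (not from the commit): writing through a (-1, K) view
# of a (M, topk, K) buffer lands in the same storage as a (M * topk, K) buffer.
import torch

M, topk, K = 4, 2, 8
mm2_out = torch.randn(M * topk, K)
inv_perm = torch.randperm(M * topk)

out_3d = torch.empty(M, topk, K)  # new-style output shape
torch.index_select(mm2_out, 0, inv_perm, out=out_3d.view(-1, K))

out_2d = torch.empty(M * topk, K)  # old-style output shape
torch.index_select(mm2_out, 0, inv_perm, out=out_2d)

assert torch.equal(out_3d.view(-1, K), out_2d)  # same values, same layout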
