
Commit a6b0a59

bnellnm and ElizaWszola authored and committed
[Kernels] MoE refactor (vllm-project#19636)
Signed-off-by: Bill Nell <bnell@redhat.com>
Signed-off-by: ElizaWszola <ewszola@redhat.com>
Co-authored-by: ElizaWszola <ewszola@redhat.com>
1 parent b419f66 commit a6b0a59

36 files changed, +2713 −1599 lines changed

benchmarks/kernels/benchmark_grouped_gemm_cutlass.py

Lines changed: 8 additions & 3 deletions
@@ -113,6 +113,7 @@ def run_cutlass_moe(
         w2_scale: torch.Tensor,
         topk_weights: torch.Tensor,
         topk_ids: torch.Tensor,
+        per_act_token: bool,
         num_repeats: int,
     ):
         for _ in range(num_repeats):
@@ -124,7 +125,8 @@ def run_cutlass_moe(
                 topk_ids,
                 w1_scale,
                 w2_scale,
-                a1_scale=a_scale,
+                per_act_token,
+                a1_scale=None,
             )

     def run_cutlass_from_graph(
@@ -148,7 +150,8 @@ def run_cutlass_from_graph(
                 topk_ids,
                 w1_scale,
                 w2_scale,
-                a1_scale=a_scale,
+                per_act_token,
+                a1_scale=None,
             )

     def run_triton_from_graph(
@@ -227,6 +230,7 @@ def replay_graph(graph, num_repeats):
         "w2_q": w2_q,
         "w1_scale": w1_scale,
         "w2_scale": w2_scale,
+        "per_act_token": per_act_token,
         # cuda graph params
         "cutlass_graph": cutlass_graph,
         "triton_graph": triton_graph,
@@ -287,12 +291,13 @@ def replay_graph(graph, num_repeats):
         w2_scale,
         topk_weights,
         topk_ids,
+        per_act_token,
         num_warmup,
     )

     results.append(
         benchmark.Timer(
-            stmt="run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids, num_runs)",  # noqa: E501
+            stmt="run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids, per_act_token, num_runs)",  # noqa: E501
             globals=globals,
             label=label,
             sub_label=sub_label,
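The benchmark now passes `per_act_token` explicitly instead of forwarding a precomputed `a1_scale`. As a rough, hedged illustration (not part of the diff, and only an assumption about what the flag controls in the quantized MoE path), per-token activation quantization keeps one fp8 scale per token row while per-tensor quantization keeps a single scale:

# Illustrative sketch only: the shape difference between per-act-token and
# per-tensor activation scales for fp8 quantization. `a` is assumed to be a
# [num_tokens, hidden] half/bfloat16 activation tensor.
import torch

def make_act_scale(a: torch.Tensor, per_act_token: bool) -> torch.Tensor:
    finfo = torch.finfo(torch.float8_e4m3fn)
    if per_act_token:
        # one scale per token row -> shape [num_tokens, 1]
        amax = a.abs().amax(dim=1, keepdim=True).clamp(min=1e-12)
    else:
        # a single scale for the whole tensor -> shape [1, 1]
        amax = a.abs().amax().clamp(min=1e-12).reshape(1, 1)
    return (amax / finfo.max).to(torch.float32)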

tests/kernels/moe/parallel_utils.py

Lines changed: 190 additions & 0 deletions
@@ -0,0 +1,190 @@ (new file)

# SPDX-License-Identifier: Apache-2.0
"""
DeepEP test utilities
"""
import dataclasses
import importlib
import os
import traceback
from typing import Callable, Optional

import torch
from torch.distributed import ProcessGroup
from torch.multiprocessing import (
    spawn)  # pyright: ignore[reportPrivateImportUsage]
from typing_extensions import Concatenate, ParamSpec

from vllm.utils import get_open_port

has_deep_ep = importlib.util.find_spec("deep_ep") is not None
if has_deep_ep:
    from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import (  # noqa: E501
        DeepEPHTPrepareAndFinalize)
    from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import (  # noqa: E501
        DeepEPLLPrepareAndFinalize)

## Parallel Processes Utils

P = ParamSpec("P")


@dataclasses.dataclass
class ProcessGroupInfo:
    world_size: int
    world_local_size: int
    rank: int
    node_rank: int
    local_rank: int
    device: torch.device


def _worker_parallel_launch(
    local_rank: int,
    world_size: int,
    world_local_size: int,
    node_rank: int,
    init_method: str,
    worker: Callable[Concatenate[ProcessGroupInfo, P], None],
    *args: P.args,
    **kwargs: P.kwargs,
) -> None:
    rank = node_rank * world_local_size + local_rank
    torch.cuda.set_device(local_rank)
    device = torch.device("cuda", local_rank)
    torch.distributed.init_process_group(
        backend="cpu:gloo,cuda:nccl",
        init_method=init_method,
        rank=rank,
        world_size=world_size,
        device_id=device,
    )
    barrier = torch.tensor([rank], device=device)
    torch.distributed.all_reduce(barrier)

    try:
        worker(
            ProcessGroupInfo(
                world_size=world_size,
                world_local_size=world_local_size,
                rank=rank,
                node_rank=node_rank,
                local_rank=local_rank,
                device=device,
            ),
            *args,
            **kwargs,
        )
    except Exception as ex:
        print(ex)
        traceback.print_exc()
        raise
    finally:
        torch.distributed.destroy_process_group()


def parallel_launch(
    world_size: int,
    worker: Callable[Concatenate[ProcessGroupInfo, P], None],
    *args: P.args,
    **kwargs: P.kwargs,
) -> None:
    assert not kwargs
    spawn(
        _worker_parallel_launch,
        args=(
            world_size,
            world_size,
            0,
            f"tcp://{os.getenv('LOCALHOST', 'localhost')}:{get_open_port()}",
            worker,
        ) + args,
        nprocs=world_size,
        join=True,
    )


## DeepEP specific utils


@dataclasses.dataclass
class DeepEPHTArgs:
    num_local_experts: int


@dataclasses.dataclass
class DeepEPLLArgs:
    max_tokens_per_rank: int
    hidden_size: int
    num_experts: int
    use_fp8_dispatch: bool


def make_deepep_ht_a2a(pg: ProcessGroup,
                       pgi: ProcessGroupInfo,
                       dp_size: int,
                       ht_args: DeepEPHTArgs,
                       q_dtype: Optional[torch.dtype] = None,
                       block_shape: Optional[list[int]] = None):

    import deep_ep

    # high throughput a2a
    num_nvl_bytes = 1024 * 1024 * 1024  # 1GB
    num_rdma_bytes, low_latency_mode, num_qps_per_rank = 0, False, 1
    buffer = deep_ep.Buffer(group=pg,
                            num_nvl_bytes=num_nvl_bytes,
                            num_rdma_bytes=num_rdma_bytes,
                            low_latency_mode=low_latency_mode,
                            num_qps_per_rank=num_qps_per_rank)
    return DeepEPHTPrepareAndFinalize(buffer=buffer,
                                      world_size=pgi.world_size,
                                      rank=pgi.rank,
                                      dp_size=dp_size,
                                      rank_expert_offset=pgi.rank *
                                      ht_args.num_local_experts)


def make_deepep_ll_a2a(pg: ProcessGroup,
                       pgi: ProcessGroupInfo,
                       dp_size: int,
                       deepep_ll_args: DeepEPLLArgs,
                       q_dtype: Optional[torch.dtype] = None,
                       block_shape: Optional[list[int]] = None):

    import deep_ep

    # low-latency a2a
    num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(
        deepep_ll_args.max_tokens_per_rank, deepep_ll_args.hidden_size,
        pgi.world_size, deepep_ll_args.num_experts)

    buffer = deep_ep.Buffer(group=pg,
                            num_rdma_bytes=num_rdma_bytes,
                            low_latency_mode=True,
                            num_qps_per_rank=deepep_ll_args.num_experts //
                            pgi.world_size)

    return DeepEPLLPrepareAndFinalize(
        buffer=buffer,
        world_size=pgi.world_size,
        dp_size=dp_size,
        max_tokens_per_rank=deepep_ll_args.max_tokens_per_rank,
        use_fp8_dispatch=deepep_ll_args.use_fp8_dispatch,
    )


def make_deepep_a2a(pg: ProcessGroup,
                    pgi: ProcessGroupInfo,
                    dp_size: int,
                    deepep_ht_args: Optional[DeepEPHTArgs],
                    deepep_ll_args: Optional[DeepEPLLArgs],
                    q_dtype: Optional[torch.dtype] = None,
                    block_shape: Optional[list[int]] = None):
    if deepep_ht_args is not None:
        assert deepep_ll_args is None
        return make_deepep_ht_a2a(pg, pgi, dp_size, deepep_ht_args, q_dtype,
                                  block_shape)

    assert deepep_ll_args is not None
    return make_deepep_ll_a2a(pg, pgi, dp_size, deepep_ll_args, q_dtype,
                              block_shape)
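A minimal usage sketch of the launch helper above, assuming the module is importable as tests.kernels.moe.parallel_utils and that one CUDA device per spawned rank is available; the worker and payload below are hypothetical and not part of the commit:

# Hypothetical usage sketch: spawn one process per GPU via parallel_launch
# and run a tiny collective inside each worker.
import torch

from tests.kernels.moe.parallel_utils import ProcessGroupInfo, parallel_launch


def _demo_worker(pgi: ProcessGroupInfo, payload: int) -> None:
    # Each rank contributes its rank id; after all_reduce every rank sees the sum.
    t = torch.tensor([pgi.rank + payload], device=pgi.device)
    torch.distributed.all_reduce(t)
    print(f"rank {pgi.rank}/{pgi.world_size}: {t.item()}")


if __name__ == "__main__":
    parallel_launch(torch.cuda.device_count(), _demo_worker, 0)

The DeepEP-specific helpers follow the same pattern: a test worker would build its all-to-all object by calling make_deepep_a2a with either DeepEPHTArgs or DeepEPLLArgs (but not both), as enforced by the asserts above.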
