Commit fa98d77

[Kernel] DeepEP dispatch-combine kernel integration (#18434)

Authored-by: Varun Sundar Rabindranath (varun-sundar-rabindranath)
Signed-off-by: Varun <vsundarr@redhat.com>
Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
1 parent 01eee40 commit fa98d77

Showing 23 changed files with 1,952 additions and 124 deletions.

csrc/moe/topk_softmax_kernels.cu

Lines changed: 14 additions & 2 deletions
@@ -516,9 +516,8 @@ void topk_softmax(
             topk,
             stream);
     }
-    else
+    else if (topk_indices.scalar_type() == at::ScalarType::UInt32)
     {
-        assert(topk_indices.scalar_type() == at::ScalarType::UInt32);
         vllm::moe::topkGatingSoftmaxKernelLauncher(
             gating_output.data_ptr<float>(),
             topk_weights.data_ptr<float>(),
@@ -530,4 +529,17 @@ void topk_softmax(
             topk,
             stream);
     }
+    else {
+        assert(topk_indices.scalar_type() == at::ScalarType::Int64);
+        vllm::moe::topkGatingSoftmaxKernelLauncher(
+            gating_output.data_ptr<float>(),
+            topk_weights.data_ptr<float>(),
+            topk_indices.data_ptr<int64_t>(),
+            token_expert_indices.data_ptr<int>(),
+            softmax_workspace.data_ptr<float>(),
+            num_tokens,
+            num_experts,
+            topk,
+            stream);
+    }
 }
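
The hunk above widens the dtype dispatch in topk_softmax: the old else branch that merely asserted UInt32 indices becomes an explicit else-if, and a new fallback branch launches the kernel with int64_t indices. Below is a minimal sketch (not part of this commit) of driving both index dtypes from Python; the vllm._custom_ops.topk_softmax wrapper name and argument order are assumptions taken from other vLLM call sites, not something this diff shows.

# Hypothetical sketch only: the wrapper name and argument order are
# assumptions; the diff above only changes the C++ dtype dispatch.
import torch

from vllm import _custom_ops as ops

num_tokens, num_experts, topk = 16, 8, 2
gating_output = torch.randn(num_tokens, num_experts,
                            dtype=torch.float32, device="cuda")
topk_weights = torch.empty(num_tokens, topk,
                           dtype=torch.float32, device="cuda")
token_expert_indices = torch.empty(num_tokens, topk,
                                   dtype=torch.int32, device="cuda")

# int64 indices take the newly added launcher branch; uint32 still works.
for idx_dtype in (torch.int64, torch.uint32):
    topk_ids = torch.empty(num_tokens, topk, dtype=idx_dtype, device="cuda")
    ops.topk_softmax(topk_weights, topk_ids, token_expert_indices,
                     gating_output)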

tests/kernels/moe/__init__.py

Whitespace-only changes.

tests/kernels/moe/deepep_utils.py

Lines changed: 188 additions & 0 deletions
@@ -0,0 +1,188 @@
# SPDX-License-Identifier: Apache-2.0
"""
DeepEP test utilities
"""
import dataclasses
import importlib
import traceback
from typing import Callable, Optional

import torch
from torch.distributed import ProcessGroup
from torch.multiprocessing import (
    spawn)  # pyright: ignore[reportPrivateImportUsage]
from typing_extensions import Concatenate, ParamSpec

has_deep_ep = importlib.util.find_spec("deep_ep") is not None
if has_deep_ep:
    from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import (  # noqa: E501
        DeepEPHTPrepareAndFinalize)
    from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import (  # noqa: E501
        DeepEPLLPrepareAndFinalize)

## Parallel Processes Utils

P = ParamSpec("P")


@dataclasses.dataclass
class ProcessGroupInfo:
    world_size: int
    world_local_size: int
    rank: int
    node_rank: int
    local_rank: int
    device: torch.device


def _worker_parallel_launch(
    local_rank: int,
    world_size: int,
    world_local_size: int,
    node_rank: int,
    init_method: str,
    worker: Callable[Concatenate[ProcessGroupInfo, P], None],
    *args: P.args,
    **kwargs: P.kwargs,
) -> None:
    rank = node_rank * world_local_size + local_rank
    torch.cuda.set_device(local_rank)
    device = torch.device("cuda", local_rank)
    torch.distributed.init_process_group(
        backend="cpu:gloo,cuda:nccl",
        init_method=init_method,
        rank=rank,
        world_size=world_size,
        device_id=device,
    )
    barrier = torch.tensor([rank], device=device)
    torch.distributed.all_reduce(barrier)

    try:
        worker(
            ProcessGroupInfo(
                world_size=world_size,
                world_local_size=world_local_size,
                rank=rank,
                node_rank=node_rank,
                local_rank=local_rank,
                device=device,
            ),
            *args,
            **kwargs,
        )
    except Exception as ex:
        print(ex)
        traceback.print_exc()
        raise
    finally:
        torch.distributed.destroy_process_group()


def parallel_launch(
    world_size: int,
    worker: Callable[Concatenate[ProcessGroupInfo, P], None],
    *args: P.args,
    **kwargs: P.kwargs,
) -> None:
    assert not kwargs
    spawn(
        _worker_parallel_launch,
        args=(
            world_size,
            world_size,
            0,
            "tcp://localhost:29500",
            worker,
        ) + args,
        nprocs=world_size,
        join=True,
    )


## DeepEP specific utils


@dataclasses.dataclass
class DeepEPHTArgs:
    num_local_experts: int


@dataclasses.dataclass
class DeepEPLLArgs:
    max_tokens_per_rank: int
    hidden_size: int
    num_experts: int
    use_fp8_dispatch: bool


def make_deepep_ht_a2a(pg: ProcessGroup,
                       pgi: ProcessGroupInfo,
                       dp_size: int,
                       ht_args: DeepEPHTArgs,
                       q_dtype: Optional[torch.dtype] = None,
                       block_shape: Optional[list[int]] = None):

    import deep_ep

    # high throughput a2a
    num_nvl_bytes = 1024 * 1024 * 1024  # 1GB
    num_rdma_bytes, low_latency_mode, num_qps_per_rank = 0, False, 1
    buffer = deep_ep.Buffer(group=pg,
                            num_nvl_bytes=num_nvl_bytes,
                            num_rdma_bytes=num_rdma_bytes,
                            low_latency_mode=low_latency_mode,
                            num_qps_per_rank=num_qps_per_rank)
    return DeepEPHTPrepareAndFinalize(buffer=buffer,
                                      world_size=pgi.world_size,
                                      rank=pgi.rank,
                                      dp_size=dp_size,
                                      rank_expert_offset=pgi.rank *
                                      ht_args.num_local_experts,
                                      quant_dtype=q_dtype,
                                      block_shape=block_shape)


def make_deepep_ll_a2a(pg: ProcessGroup,
                       pgi: ProcessGroupInfo,
                       dp_size: int,
                       deepep_ll_args: DeepEPLLArgs,
                       q_dtype: Optional[torch.dtype] = None,
                       block_shape: Optional[list[int]] = None):

    import deep_ep

    # low-latency a2a
    num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(
        deepep_ll_args.max_tokens_per_rank, deepep_ll_args.hidden_size,
        pgi.world_size, deepep_ll_args.num_experts)

    buffer = deep_ep.Buffer(group=pg,
                            num_rdma_bytes=num_rdma_bytes,
                            low_latency_mode=True,
                            num_qps_per_rank=deepep_ll_args.num_experts //
                            pgi.world_size)
    return DeepEPLLPrepareAndFinalize(
        buffer=buffer,
        world_size=pgi.world_size,
        dp_size=dp_size,
        max_tokens_per_rank=deepep_ll_args.max_tokens_per_rank,
        quant_dtype=q_dtype,
        use_fp8_dispatch=deepep_ll_args.use_fp8_dispatch,
    )


def make_deepep_a2a(pg: ProcessGroup,
                    pgi: ProcessGroupInfo,
                    dp_size: int,
                    deepep_ht_args: Optional[DeepEPHTArgs],
                    deepep_ll_args: Optional[DeepEPLLArgs],
                    q_dtype: Optional[torch.dtype] = None,
                    block_shape: Optional[list[int]] = None):
    if deepep_ht_args is not None:
        assert deepep_ll_args is None
        return make_deepep_ht_a2a(pg, pgi, dp_size, deepep_ht_args, q_dtype,
                                  block_shape)

    assert deepep_ll_args is not None
    return make_deepep_ll_a2a(pg, pgi, dp_size, deepep_ll_args, q_dtype)
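
As a usage illustration (not part of this commit), the helpers above compose roughly as follows: parallel_launch spawns one worker per rank, and each worker passes its ProcessGroupInfo plus either DeepEPHTArgs or DeepEPLLArgs to make_deepep_a2a. The worker body, tensor sizes, and absolute import path below are hypothetical.

# Hypothetical usage sketch; sizes and the import path are assumptions,
# not taken from this commit.
import torch

from tests.kernels.moe.deepep_utils import (DeepEPLLArgs, ProcessGroupInfo,
                                            make_deepep_a2a, parallel_launch)


def _ll_worker(pgi: ProcessGroupInfo, dp_size: int) -> None:
    # parallel_launch already initialized the default process group.
    pg = torch.distributed.group.WORLD
    ll_args = DeepEPLLArgs(max_tokens_per_rank=128,
                           hidden_size=2048,
                           num_experts=32,
                           use_fp8_dispatch=False)
    a2a = make_deepep_a2a(pg, pgi, dp_size,
                          deepep_ht_args=None,
                          deepep_ll_args=ll_args)
    # ... drive prepare/finalize through `a2a` in the actual test body ...


if __name__ == "__main__":
    world_size = torch.cuda.device_count()
    # dp_size == world_size in this sketch.
    parallel_launch(world_size, _ll_worker, world_size)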
