Commit 52f935c

move deepep_utils -> parallel_utils
Signed-off-by: Bill Nell <bnell@redhat.com>
1 parent: c88f65d

File tree

5 files changed: +203, -6 lines changed

tests/kernels/moe/parallel_utils.py

Lines changed: 197 additions & 0 deletions
@@ -0,0 +1,197 @@
# SPDX-License-Identifier: Apache-2.0
"""
DeepEP test utilities
"""
import dataclasses
import importlib
import os
import socket
import traceback
from contextlib import closing
from typing import Callable, Optional

import torch
from torch.distributed import ProcessGroup
from torch.multiprocessing import (
    spawn)  # pyright: ignore[reportPrivateImportUsage]
from typing_extensions import Concatenate, ParamSpec

has_deep_ep = importlib.util.find_spec("deep_ep") is not None
if has_deep_ep:
    from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import (  # noqa: E501
        DeepEPHTPrepareAndFinalize)
    from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import (  # noqa: E501
        DeepEPLLPrepareAndFinalize)

## Parallel Processes Utils

P = ParamSpec("P")


@dataclasses.dataclass
class ProcessGroupInfo:
    world_size: int
    world_local_size: int
    rank: int
    node_rank: int
    local_rank: int
    device: torch.device


def _worker_parallel_launch(
    local_rank: int,
    world_size: int,
    world_local_size: int,
    node_rank: int,
    init_method: str,
    worker: Callable[Concatenate[ProcessGroupInfo, P], None],
    *args: P.args,
    **kwargs: P.kwargs,
) -> None:
    rank = node_rank * world_local_size + local_rank
    torch.cuda.set_device(local_rank)
    device = torch.device("cuda", local_rank)
    torch.distributed.init_process_group(
        backend="cpu:gloo,cuda:nccl",
        init_method=init_method,
        rank=rank,
        world_size=world_size,
        device_id=device,
    )
    barrier = torch.tensor([rank], device=device)
    torch.distributed.all_reduce(barrier)

    try:
        worker(
            ProcessGroupInfo(
                world_size=world_size,
                world_local_size=world_local_size,
                rank=rank,
                node_rank=node_rank,
                local_rank=local_rank,
                device=device,
            ),
            *args,
            **kwargs,
        )
    except Exception as ex:
        print(ex)
        traceback.print_exc()
        raise
    finally:
        torch.distributed.destroy_process_group()


def find_free_port():
    with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
        s.bind(('', 0))
        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        return s.getsockname()[1]


def parallel_launch(
    world_size: int,
    worker: Callable[Concatenate[ProcessGroupInfo, P], None],
    *args: P.args,
    **kwargs: P.kwargs,
) -> None:
    assert not kwargs
    spawn(
        _worker_parallel_launch,
        args=(
            world_size,
            world_size,
            0,
            f"tcp://{os.getenv('LOCALHOST', 'localhost')}:{find_free_port()}",
            worker,
        ) + args,
        nprocs=world_size,
        join=True,
    )


## DeepEP specific utils


@dataclasses.dataclass
class DeepEPHTArgs:
    num_local_experts: int


@dataclasses.dataclass
class DeepEPLLArgs:
    max_tokens_per_rank: int
    hidden_size: int
    num_experts: int
    use_fp8_dispatch: bool


def make_deepep_ht_a2a(pg: ProcessGroup,
                       pgi: ProcessGroupInfo,
                       dp_size: int,
                       ht_args: DeepEPHTArgs,
                       q_dtype: Optional[torch.dtype] = None,
                       block_shape: Optional[list[int]] = None):

    import deep_ep

    # high throughput a2a
    num_nvl_bytes = 1024 * 1024 * 1024  # 1GB
    num_rdma_bytes, low_latency_mode, num_qps_per_rank = 0, False, 1
    buffer = deep_ep.Buffer(group=pg,
                            num_nvl_bytes=num_nvl_bytes,
                            num_rdma_bytes=num_rdma_bytes,
                            low_latency_mode=low_latency_mode,
                            num_qps_per_rank=num_qps_per_rank)
    return DeepEPHTPrepareAndFinalize(buffer=buffer,
                                      world_size=pgi.world_size,
                                      rank=pgi.rank,
                                      dp_size=dp_size,
                                      rank_expert_offset=pgi.rank *
                                      ht_args.num_local_experts)


def make_deepep_ll_a2a(pg: ProcessGroup,
                       pgi: ProcessGroupInfo,
                       dp_size: int,
                       deepep_ll_args: DeepEPLLArgs,
                       q_dtype: Optional[torch.dtype] = None,
                       block_shape: Optional[list[int]] = None):

    import deep_ep

    # low-latency a2a
    num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(
        deepep_ll_args.max_tokens_per_rank, deepep_ll_args.hidden_size,
        pgi.world_size, deepep_ll_args.num_experts)

    buffer = deep_ep.Buffer(group=pg,
                            num_rdma_bytes=num_rdma_bytes,
                            low_latency_mode=True,
                            num_qps_per_rank=deepep_ll_args.num_experts //
                            pgi.world_size)

    return DeepEPLLPrepareAndFinalize(
        buffer=buffer,
        world_size=pgi.world_size,
        dp_size=dp_size,
        max_tokens_per_rank=deepep_ll_args.max_tokens_per_rank,
        use_fp8_dispatch=deepep_ll_args.use_fp8_dispatch,
    )


def make_deepep_a2a(pg: ProcessGroup,
                    pgi: ProcessGroupInfo,
                    dp_size: int,
                    deepep_ht_args: Optional[DeepEPHTArgs],
                    deepep_ll_args: Optional[DeepEPLLArgs],
                    q_dtype: Optional[torch.dtype] = None,
                    block_shape: Optional[list[int]] = None):
    if deepep_ht_args is not None:
        assert deepep_ll_args is None
        return make_deepep_ht_a2a(pg, pgi, dp_size, deepep_ht_args, q_dtype,
                                  block_shape)

    assert deepep_ll_args is not None
    return make_deepep_ll_a2a(pg, pgi, dp_size, deepep_ll_args, q_dtype,
                              block_shape)
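
For orientation, a minimal sketch of how a test might drive these helpers. The worker and test names below are hypothetical, not part of this commit: parallel_launch spawns world_size single-node processes, _worker_parallel_launch initializes a gloo/NCCL process group in each, and the worker receives a ProcessGroupInfo followed by any extra positional arguments.

# Hypothetical usage sketch, not part of this commit. Requires two
# CUDA devices; real tests run MoE kernels inside the worker instead.
import torch

from .parallel_utils import ProcessGroupInfo, parallel_launch


def _sum_ranks_worker(pgi: ProcessGroupInfo, scale: int) -> None:
    # The process group is already initialized by _worker_parallel_launch,
    # so collectives can be issued directly.
    x = torch.tensor([pgi.rank * scale], device=pgi.device)
    torch.distributed.all_reduce(x)  # default op: sum across all ranks
    assert x.item() == scale * sum(range(pgi.world_size))


def test_parallel_launch_smoke():
    # Spawn 2 processes on this node; extra positional args are
    # forwarded to the worker after the ProcessGroupInfo.
    parallel_launch(2, _sum_ranks_worker, 10)

The DeepEP helpers follow the same pattern inside a worker: make_deepep_a2a builds the high-throughput all-to-all when DeepEPHTArgs is given and the low-latency one when DeepEPLLArgs is given, asserting that exactly one of the two is provided.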

tests/kernels/moe/test_deepep_deepgemm_moe.py

Lines changed: 2 additions & 2 deletions
@@ -22,7 +22,7 @@
     per_token_group_quant_fp8)
 from vllm.platforms import current_platform
 
-from .deepep_utils import ProcessGroupInfo, parallel_launch
+from .parallel_utils import ProcessGroupInfo, parallel_launch
 from .utils import make_test_weights
 
 has_deep_ep = importlib.util.find_spec("deep_ep") is not None
@@ -34,7 +34,7 @@
     from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import (  # noqa: E501
         DeepEPLLPrepareAndFinalize)
 
-    from .deepep_utils import DeepEPHTArgs, DeepEPLLArgs, make_deepep_a2a
+    from .parallel_utils import DeepEPHTArgs, DeepEPLLArgs, make_deepep_a2a
 
 if has_deep_gemm:
     from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (

tests/kernels/moe/test_deepep_moe.py

Lines changed: 2 additions & 2 deletions
@@ -23,7 +23,7 @@
     per_token_group_quant_fp8)
 from vllm.platforms import current_platform
 
-from .utils import ProcessGroupInfo, parallel_launch
+from .parallel_utils import ProcessGroupInfo, parallel_launch
 
 has_deep_ep = importlib.util.find_spec("deep_ep") is not None
 
@@ -33,7 +33,7 @@
     from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import (  # noqa: E501
         DeepEPLLPrepareAndFinalize)
 
-    from .deepep_utils import DeepEPHTArgs, DeepEPLLArgs, make_deepep_a2a
+    from .parallel_utils import DeepEPHTArgs, DeepEPLLArgs, make_deepep_a2a
 
 requires_deep_ep = pytest.mark.skipif(
     not has_deep_ep,

tests/kernels/moe/test_pplx_cutlass_moe.py

Lines changed: 1 addition & 1 deletion
@@ -15,7 +15,7 @@
     FusedMoEModularKernel)
 from vllm.platforms import current_platform
 
-from .deepep_utils import ProcessGroupInfo, parallel_launch
+from .parallel_utils import ProcessGroupInfo, parallel_launch
 
 try:
     from pplx_kernels import AllToAll

tests/kernels/moe/test_pplx_moe.py

Lines changed: 1 addition & 1 deletion
@@ -31,7 +31,7 @@
 from vllm.platforms import current_platform
 from vllm.utils import round_up
 
-from .deepep_utils import ProcessGroupInfo, parallel_launch
+from .parallel_utils import ProcessGroupInfo, parallel_launch
 
 requires_pplx = pytest.mark.skipif(
     not has_pplx,
