@@ -1,10 +1,11 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
+from typing import Optional
+
 import pytest
 import torch
 
-from tests.pplx_utils import ProcessGroupInfo, parallel_launch
 from vllm import _custom_ops as ops
 from vllm.config import VllmConfig, set_current_vllm_config
 from vllm.model_executor.layers.activation import SiluAndMul
@@ -14,6 +15,8 @@
     FusedMoEModularKernel)
 from vllm.platforms import current_platform
 
+from .deepep_utils import ProcessGroupInfo, parallel_launch
+
 try:
     from pplx_kernels import AllToAll
     from pplx_kernels.nvshmem import (nvshmem_alloc_empty_unique_id,
@@ -64,6 +67,7 @@ def pplx_cutlass_moe(
     out_dtype,
     per_act_token: bool,
     per_out_ch: bool,
+    group_name: Optional[str],
 ):
     from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import (
         PplxPrepareAndFinalize)
@@ -84,7 +88,7 @@ def pplx_cutlass_moe(
     else:
         scale_elems = (hidden_dim + block_size - 1) // block_size
 
-    ata = AllToAll.internode(
+    args = dict(
         max_num_tokens=max_num_tokens,
         num_experts=num_experts,
         experts_per_token=topk,
@@ -96,6 +100,12 @@ def pplx_cutlass_moe(
         hidden_dim_scale_bytes=scale_elems * torch.float32.itemsize,
     )
 
+    if group_name is None:
+        ata = AllToAll.internode(**args)
+    else:
+        args["group_name"] = group_name
+        ata = AllToAll.intranode(**args)
+
     w1 = w1.to(device)
     w2 = w2.to(device)
     w1_scale = w1_scale.to(device)
@@ -113,7 +123,10 @@ def pplx_cutlass_moe(
     )
 
     experts = CutlassExpertsFp8((num_experts + world_size - 1) // world_size,
-                                out_dtype, per_act_token, per_out_ch)
+                                out_dtype,
+                                per_act_token,
+                                per_out_ch,
+                                use_batched_format=True)
 
     fused_cutlass_experts = FusedMoEModularKernel(
         prepare_finalize,
@@ -184,19 +197,25 @@ def _pplx_moe(
     w2_full: torch.Tensor,
     per_act_token: bool,
     per_out_ch: bool,
+    use_internode: bool,
 ):
-    uid = nvshmem_get_unique_id(
-    ) if pgi.rank == 0 else nvshmem_alloc_empty_unique_id()
-    torch.distributed.broadcast(uid, src=0)
-    nvshmem_init(uid, pgi.rank, pgi.world_size)
+    if use_internode:
+        uid = nvshmem_get_unique_id(
+        ) if pgi.rank == 0 else nvshmem_alloc_empty_unique_id()
+        torch.distributed.broadcast(uid, src=0)
+        nvshmem_init(uid, pgi.rank, pgi.world_size)
+    else:
+        group_ranks = list(range(pgi.world_size))
+        cpu_group = torch.distributed.new_group(group_ranks, backend="gloo")
+        group_name = cpu_group.group_name
 
     with set_current_vllm_config(vllm_config):
         torch_output = torch_moe2(a_full, w1_full, w2_full, topk_weights,
                                   topk_ids)
         pplx_output = pplx_cutlass_moe(pgi, dp_size, a, w1, w2, w1_scale,
                                        w2_scale, topk_weights, topk_ids,
                                        a1_scale, out_dtype, per_act_token,
-                                       per_out_ch)
+                                       per_out_ch, group_name)
 
     torch_output = chunk_by_rank(torch_output, pgi.rank,
                                  pgi.world_size).to(pplx_output.device)
@@ -207,7 +226,8 @@ def _pplx_moe(
 
     torch.testing.assert_close(pplx_output, torch_output, atol=0.05, rtol=0)
 
-    nvshmem_finalize()
+    if use_internode:
+        nvshmem_finalize()
 
 
 @pytest.mark.parametrize("m", [2, 224])
@@ -218,6 +238,7 @@ def _pplx_moe(
 @pytest.mark.parametrize("per_act_token", [True, False])
 @pytest.mark.parametrize("per_out_ch", [True, False])
 @pytest.mark.parametrize("world_dp_size", [[2, 1]])  #, [4, 2]])
+@pytest.mark.parametrize("use_internode", [False])
 @pytest.mark.skipif(
     (lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))(
         current_platform.get_device_capability()),
@@ -232,6 +253,7 @@ def test_cutlass_moe_pplx(
     per_act_token: bool,
     per_out_ch: bool,
     world_dp_size: tuple[int, int],
+    use_internode: bool,
 ):
     current_platform.seed_everything(7)
 
@@ -284,4 +306,5 @@ def test_cutlass_moe_pplx(
 
     parallel_launch(world_size, _pplx_moe, dp_size, a, w1_q, w2_q,
                     w1_scale, w2_scale, topk_weights, topk_ids, a_scale1,
-                    dtype, a, w1_d, w2_d, per_act_token, per_out_ch)
+                    dtype, a, w1_d, w2_d, per_act_token, per_out_ch,
+                    use_internode)