Commit 29fa5ca

[Kernels] Add activation chunking logic to FusedMoEModularKernel (#19168)
Signed-off-by: Bill Nell <bnell@redhat.com>
1 parent b2d9be6, commit 29fa5ca

15 files changed: +458 −396 lines

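The commit title is the only description here, so as orientation: activation chunking means slicing the flattened token activations into bounded pieces and running the expert computation piece by piece, so intermediate workspaces scale with the chunk size rather than with the batch size (the new (1024 * 128, 1024, 1024) case in test_cutlass_moe.py below is large enough to force multiple chunks under vLLM's VLLM_FUSED_MOE_CHUNK_SIZE setting). The sketch below illustrates that pattern only; it is not the code added to FusedMoEModularKernel, and chunked_experts, apply_experts, and chunk_size are placeholder names.

# Illustrative sketch of per-chunk expert execution (not the vLLM implementation).
# `apply_experts` stands in for the expert compute of a modular MoE kernel.
import torch

def chunked_experts(hidden_states: torch.Tensor,
                    topk_weights: torch.Tensor,
                    topk_ids: torch.Tensor,
                    apply_experts,
                    chunk_size: int = 32 * 1024) -> torch.Tensor:
    num_tokens = hidden_states.shape[0]
    out = torch.empty_like(hidden_states)
    for start in range(0, num_tokens, chunk_size):
        end = min(start + chunk_size, num_tokens)
        # Each slice reuses the same bounded workspace inside apply_experts.
        out[start:end] = apply_experts(hidden_states[start:end],
                                       topk_weights[start:end],
                                       topk_ids[start:end])
    return out
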
tests/kernels/moe/test_cutlass_moe.py
Lines changed: 1 addition & 0 deletions

@@ -29,6 +29,7 @@
     (224, 1024, 1536),
     (224, 3072, 1024),
     (224, 3072, 1536),
+    (1024 * 128, 1024, 1024),
 ]
 
 vllm_config = VllmConfig(parallel_config=ParallelConfig(

tests/kernels/moe/test_moe.py
Lines changed: 22 additions & 1 deletion

@@ -15,7 +15,8 @@
 from tests.kernels.utils import opcheck, stack_and_dev, torch_moe
 from vllm.config import VllmConfig, set_current_vllm_config
 from vllm.model_executor.layers.fused_moe import fused_moe
-from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
+from vllm.model_executor.layers.fused_moe.fused_moe import (
+    fused_topk, modular_triton_fused_moe)
 from vllm.model_executor.layers.fused_moe.moe_torch_iterative import (
     fused_moe as iterative_moe)
 from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
@@ -76,6 +77,13 @@ def test_fused_moe(
     else:
         e_map = None
 
+    m_fused_moe = modular_triton_fused_moe(use_fp8_w8a8=False,
+                                           use_int8_w8a8=False,
+                                           use_int8_w8a16=False,
+                                           use_int4_w4a16=False,
+                                           per_channel_quant=False,
+                                           block_shape=None)
+
     with set_current_vllm_config(vllm_config):
         torch_output = torch_moe(a, w1, w2, score, topk, e_map)
         iterative_output = iterative_moe(a,
@@ -103,7 +111,20 @@ def test_fused_moe(
                                   expert_map=e_map,
                                   renormalize=False)
 
+        topk_weights, topk_ids, _ = fused_topk(a, score, topk, False)
+        m_triton_output = m_fused_moe(a,
+                                      w1,
+                                      w2,
+                                      topk_weights,
+                                      topk_ids,
+                                      global_num_experts=e,
+                                      expert_map=e_map)
+
     torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)
+    torch.testing.assert_close(m_triton_output,
+                               torch_output,
+                               atol=2e-2,
+                               rtol=0)
     torch.testing.assert_close(iterative_output,
                                torch_output,
                                atol=2e-2,

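The new assertions above exercise the modular Triton path end to end: build the kernel object once, compute top-k routing with fused_topk, then call it the same way the monolithic fused_moe is called. A condensed, self-contained version of that flow (the shapes, dtype, and device below are illustrative, not taken from the test) would look roughly like this:

# Condensed usage of the modular Triton path shown in the diff above;
# m, k, n, e, topk and the half-precision CUDA tensors are assumptions.
import torch
from vllm.model_executor.layers.fused_moe.fused_moe import (
    fused_topk, modular_triton_fused_moe)

m, k, n, e, topk = 64, 128, 256, 8, 2
a = torch.randn(m, k, device="cuda", dtype=torch.half)
w1 = torch.randn(e, 2 * n, k, device="cuda", dtype=torch.half)
w2 = torch.randn(e, k, n, device="cuda", dtype=torch.half)
score = torch.randn(m, e, device="cuda", dtype=torch.half)

m_fused_moe = modular_triton_fused_moe(use_fp8_w8a8=False,
                                       use_int8_w8a8=False,
                                       use_int8_w8a16=False,
                                       use_int4_w4a16=False,
                                       per_channel_quant=False,
                                       block_shape=None)
topk_weights, topk_ids, _ = fused_topk(a, score, topk, False)
out = m_fused_moe(a, w1, w2, topk_weights, topk_ids,
                  global_num_experts=e, expert_map=None)
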
tests/kernels/moe/test_pplx_cutlass_moe.py
Lines changed: 33 additions & 10 deletions

@@ -1,10 +1,11 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
+from typing import Optional
+
 import pytest
 import torch
 
-from tests.pplx_utils import ProcessGroupInfo, parallel_launch
 from vllm import _custom_ops as ops
 from vllm.config import VllmConfig, set_current_vllm_config
 from vllm.model_executor.layers.activation import SiluAndMul
@@ -14,6 +15,8 @@
     FusedMoEModularKernel)
 from vllm.platforms import current_platform
 
+from .deepep_utils import ProcessGroupInfo, parallel_launch
+
 try:
     from pplx_kernels import AllToAll
     from pplx_kernels.nvshmem import (nvshmem_alloc_empty_unique_id,
@@ -64,6 +67,7 @@ def pplx_cutlass_moe(
     out_dtype,
     per_act_token: bool,
     per_out_ch: bool,
+    group_name: Optional[str],
 ):
     from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import (
         PplxPrepareAndFinalize)
@@ -84,7 +88,7 @@ def pplx_cutlass_moe(
     else:
         scale_elems = (hidden_dim + block_size - 1) // block_size
 
-    ata = AllToAll.internode(
+    args = dict(
         max_num_tokens=max_num_tokens,
         num_experts=num_experts,
         experts_per_token=topk,
@@ -96,6 +100,12 @@ def pplx_cutlass_moe(
         hidden_dim_scale_bytes=scale_elems * torch.float32.itemsize,
     )
 
+    if group_name is None:
+        ata = AllToAll.internode(**args)
+    else:
+        args["group_name"] = group_name
+        ata = AllToAll.intranode(**args)
+
     w1 = w1.to(device)
     w2 = w2.to(device)
     w1_scale = w1_scale.to(device)
@@ -113,7 +123,10 @@ def pplx_cutlass_moe(
     )
 
     experts = CutlassExpertsFp8((num_experts + world_size - 1) // world_size,
-                                out_dtype, per_act_token, per_out_ch)
+                                out_dtype,
+                                per_act_token,
+                                per_out_ch,
+                                use_batched_format=True)
 
     fused_cutlass_experts = FusedMoEModularKernel(
         prepare_finalize,
@@ -184,19 +197,25 @@ def _pplx_moe(
     w2_full: torch.Tensor,
     per_act_token: bool,
     per_out_ch: bool,
+    use_internode: bool,
 ):
-    uid = nvshmem_get_unique_id(
-    ) if pgi.rank == 0 else nvshmem_alloc_empty_unique_id()
-    torch.distributed.broadcast(uid, src=0)
-    nvshmem_init(uid, pgi.rank, pgi.world_size)
+    if use_internode:
+        uid = nvshmem_get_unique_id(
+        ) if pgi.rank == 0 else nvshmem_alloc_empty_unique_id()
+        torch.distributed.broadcast(uid, src=0)
+        nvshmem_init(uid, pgi.rank, pgi.world_size)
+    else:
+        group_ranks = list(range(pgi.world_size))
+        cpu_group = torch.distributed.new_group(group_ranks, backend="gloo")
+        group_name = cpu_group.group_name
 
     with set_current_vllm_config(vllm_config):
         torch_output = torch_moe2(a_full, w1_full, w2_full, topk_weights,
                                   topk_ids)
         pplx_output = pplx_cutlass_moe(pgi, dp_size, a, w1, w2, w1_scale,
                                        w2_scale, topk_weights, topk_ids,
                                        a1_scale, out_dtype, per_act_token,
-                                       per_out_ch)
+                                       per_out_ch, group_name)
 
     torch_output = chunk_by_rank(torch_output, pgi.rank,
                                  pgi.world_size).to(pplx_output.device)
@@ -207,7 +226,8 @@ def _pplx_moe(
 
     torch.testing.assert_close(pplx_output, torch_output, atol=0.05, rtol=0)
 
-    nvshmem_finalize()
+    if use_internode:
+        nvshmem_finalize()
 
 
 @pytest.mark.parametrize("m", [2, 224])
@@ -218,6 +238,7 @@ def _pplx_moe(
 @pytest.mark.parametrize("per_act_token", [True, False])
 @pytest.mark.parametrize("per_out_ch", [True, False])
 @pytest.mark.parametrize("world_dp_size", [[2, 1]]) #, [4, 2]])
+@pytest.mark.parametrize("use_internode", [False])
 @pytest.mark.skipif(
     (lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))(
         current_platform.get_device_capability()),
@@ -232,6 +253,7 @@ def test_cutlass_moe_pplx(
     per_act_token: bool,
     per_out_ch: bool,
     world_dp_size: tuple[int, int],
+    use_internode: bool,
 ):
     current_platform.seed_everything(7)
 
@@ -284,4 +306,5 @@ def test_cutlass_moe_pplx(
 
     parallel_launch(world_size, _pplx_moe, dp_size, a, w1_q, w2_q,
                     w1_scale, w2_scale, topk_weights, topk_ids, a_scale1,
-                    dtype, a, w1_d, w2_d, per_act_token, per_out_ch)
+                    dtype, a, w1_d, w2_d, per_act_token, per_out_ch,
+                    use_internode)
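
The most consequential change in this file is the switch from unconditionally calling AllToAll.internode to choosing between internode and intranode based on whether a Gloo process-group name is available. Pulled out of the test for readability, the selection logic reads as below; the helper names are illustrative and not part of the commit, while the AllToAll.internode / AllToAll.intranode calls and the ProcessGroup.group_name attribute come from the diff above.

# Hedged sketch of the internode/intranode selection used in the test.
from typing import Optional

import torch.distributed as dist


def resolve_group_name(world_size: int, use_internode: bool) -> Optional[str]:
    if use_internode:
        return None  # NVSHMEM is initialized separately for the internode path.
    cpu_group = dist.new_group(list(range(world_size)), backend="gloo")
    return cpu_group.group_name


def make_all_to_all(AllToAll, args: dict, group_name: Optional[str]):
    if group_name is None:
        # Multi-node path backed by NVSHMEM.
        return AllToAll.internode(**args)
    # Single-node path backed by a torch process group.
    return AllToAll.intranode(**dict(args, group_name=group_name))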
