Skip to content

[Kernels][Bugfix] Use torch op for all kernels in FusedMoE forward. Add additional testing for cudagraphs. #19717

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 18 commits into from
Jun 25, 2025
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion tests/kernels/moe/test_cutlass_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,10 @@
(224, 1024, 1536),
(224, 3072, 1024),
(224, 3072, 1536),
(1024 * 128, 1024, 1024),
(32768, 1024, 1024),
# These sizes trigger wrong answers.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@bnellnm could you share some e2e numerical testing results here? this means the current approach still has correctness problems?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is an issue with the cutlass kernels themselves (separate from chunking). I'm not sure if the tolerances need to be changed or if there's a real problem with cutlass. @ElizaWszola can probably provide some more insight here.

#(7232, 2048, 5120),
#(40000, 2048, 5120),
]

vllm_config = VllmConfig(parallel_config=ParallelConfig(
Expand Down Expand Up @@ -232,8 +235,10 @@ def test_cutlass_moe_8_bit_no_graph(
topk: int,
per_act_token: bool,
per_out_ch: bool,
monkeypatch,
):
current_platform.seed_everything(7)
monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
with set_current_vllm_config(vllm_config):
mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token,
per_out_ch)
Expand Down Expand Up @@ -274,8 +279,10 @@ def test_cutlass_moe_8_bit_cuda_graph(
topk: int,
per_act_token: bool,
per_out_ch: bool,
monkeypatch,
):
current_platform.seed_everything(7)
monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
with set_current_vllm_config(vllm_config):
dtype = torch.half

Expand Down Expand Up @@ -329,8 +336,10 @@ def test_cutlass_moe_8_bit_EP(
per_act_token: bool,
per_out_channel: bool,
ep_size: int,
monkeypatch,
):
current_platform.seed_everything(7)
monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
with set_current_vllm_config(vllm_config):
mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token,
per_out_channel)
Expand Down
209 changes: 157 additions & 52 deletions tests/kernels/moe/test_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@

Run `pytest tests/kernels/test_moe.py`.
"""
import functools
from typing import Callable, Optional, Union

import pytest
import torch
from torch.nn import Parameter
Expand Down Expand Up @@ -40,14 +43,84 @@
vllm_config.scheduler_config.max_model_len = 8192


@pytest.mark.parametrize("m", [1, 33, 64, 222, 1024 * 128])
def run_moe_test(
baseline: Union[Callable, torch.Tensor],
moe_fn: Callable,
a: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
score: torch.Tensor,
topk: int,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
padding: bool = False,
use_compile: bool = False,
use_cudagraph: bool = False,
atol: float = 2e-2,
rtol: float = 0,
) -> torch.Tensor:
if isinstance(baseline, torch.Tensor):
baseline_output = baseline
else:
baseline_output = baseline(a,
w1,
w2,
score,
topk,
global_num_experts=global_num_experts,
expert_map=expert_map)

# Pad the weight if moe padding is enabled
if padding:
w1 = F.pad(w1, (0, 128), "constant", 0)[..., 0:-128]
w2 = F.pad(w2, (0, 128), "constant", 0)[..., 0:-128]

if use_compile:
moe_fn = torch.compile(moe_fn, backend="inductor", fullgraph=True)
torch._dynamo.mark_dynamic(a, 0)
torch._dynamo.mark_dynamic(score, 0)

test_output = moe_fn(a,
w1,
w2,
score,
topk,
global_num_experts=global_num_experts,
expert_map=expert_map)

if use_cudagraph:
test_output.fill_(0)
stream = torch.cuda.Stream()
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph, stream=stream):
test_output = moe_fn(a,
w1,
w2,
score,
topk,
global_num_experts=global_num_experts,
expert_map=expert_map)
torch.cuda.synchronize()
graph.replay()
torch.cuda.synchronize()

torch.testing.assert_close(test_output,
baseline_output,
atol=atol,
rtol=rtol)

return baseline_output


@pytest.mark.parametrize("m", [1, 33, 64, 222, 32768, 40000])
@pytest.mark.parametrize("n", [128, 1024, 2048])
@pytest.mark.parametrize("k", [128, 511, 1024])
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("ep_size", EP_SIZE)
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("padding", [True, False])
@pytest.mark.parametrize("chunk_size", [8192])
def test_fused_moe(
m: int,
n: int,
Expand All @@ -57,7 +130,17 @@ def test_fused_moe(
ep_size: int,
dtype: torch.dtype,
padding: bool,
chunk_size: int,
monkeypatch,
):
current_platform.seed_everything(7)

monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", str(chunk_size))

#
# Setup test data
#

a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
Expand All @@ -77,58 +160,70 @@ def test_fused_moe(
else:
e_map = None

m_fused_moe = modular_triton_fused_moe(use_fp8_w8a8=False,
use_int8_w8a8=False,
use_int8_w8a16=False,
use_int4_w4a16=False,
per_channel_quant=False,
block_shape=None)

with set_current_vllm_config(vllm_config):
torch_output = torch_moe(a, w1, w2, score, topk, e_map)
iterative_output = iterative_moe(a,
w1,
w2,
score,
topk,
global_num_experts=e,
expert_map=e_map,
renormalize=False)

# Pad the weight if moe padding is enabled
if padding:
w1 = F.pad(w1, (0, 128), "constant", 0)[..., 0:-128]
torch.cuda.empty_cache()
w2 = F.pad(w2, (0, 128), "constant", 0)[..., 0:-128]
torch.cuda.empty_cache()
#
# Setup test functions
#

m_fused_moe_fn = modular_triton_fused_moe(use_fp8_w8a8=False,
use_int8_w8a8=False,
use_int8_w8a16=False,
use_int4_w4a16=False,
per_channel_quant=False,
block_shape=None)

def m_fused_moe(
a: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
score: torch.Tensor,
topk: int,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
) -> torch.Tensor:
topk_weights, topk_ids, _ = fused_topk(a, score, topk, False)
return m_fused_moe_fn(a,
w1,
w2,
topk_weights,
topk_ids,
global_num_experts=global_num_experts,
expert_map=expert_map)

fused_moe_fn = functools.partial(fused_moe, renormalize=False)

#
# Run tests
#
runner = functools.partial(
run_moe_test,
a=a,
w1=w1,
w2=w2,
score=score,
topk=topk,
global_num_experts=e,
expert_map=e_map,
padding=padding,
)

triton_output = fused_moe(a,
w1,
w2,
score,
topk,
global_num_experts=e,
expert_map=e_map,
renormalize=False)
# Note: for now use_compile will error out if the problem size is
# large enough to trigger chunking. I'm leaving the flag and
# setup code in case we are able to revisit this later.
use_compile = False

topk_weights, topk_ids, _ = fused_topk(a, score, topk, False)
m_triton_output = m_fused_moe(a,
w1,
w2,
topk_weights,
topk_ids,
global_num_experts=e,
expert_map=e_map)
use_cudagraph = (n >= 1024 and k >= 1024
and current_platform.is_cuda_alike())

torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)
torch.testing.assert_close(m_triton_output,
torch_output,
atol=2e-2,
rtol=0)
torch.testing.assert_close(iterative_output,
torch_output,
atol=2e-2,
rtol=0)
with set_current_vllm_config(vllm_config):
baseline_output = runner(torch_moe, iterative_moe)
runner(baseline_output,
fused_moe_fn,
use_compile=use_compile,
use_cudagraph=use_cudagraph)
runner(baseline_output,
m_fused_moe,
use_compile=use_compile,
use_cudagraph=use_cudagraph)


@pytest.mark.parametrize("m", [1, 32, 222])
Expand Down Expand Up @@ -238,7 +333,12 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int,
w1_zp=w1_qzeros if has_zp else None,
w2_zp=w2_qzeros if has_zp else None,
block_shape=[0, group_size])
torch_output = torch_moe(a, w1_ref, w2_ref, score, topk, e_map)
torch_output = torch_moe(a,
w1_ref,
w2_ref,
score,
topk,
expert_map=e_map)

torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)

Expand Down Expand Up @@ -546,7 +646,12 @@ def test_fused_marlin_moe(
topk_weights, topk_ids, _ = fused_topk(a, score, topk, False)

with set_current_vllm_config(vllm_config):
torch_output = torch_moe(a, w_ref1, w_ref2, score, topk, e_map)
torch_output = torch_moe(a,
w_ref1,
w_ref2,
score,
topk,
expert_map=e_map)

marlin_output = torch.ops.vllm.fused_marlin_moe(
a,
Expand Down
2 changes: 1 addition & 1 deletion tests/kernels/moe/test_nvfp4_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def test_cutlass_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int,
device=w2.device,
block_size=quant_blocksize)

torch_output = torch_moe(a_in_dtype, w1_d, w2_d, score, topk, None)
torch_output = torch_moe(a_in_dtype, w1_d, w2_d, score, topk)

torch.testing.assert_close(torch_output,
cutlass_output,
Expand Down
22 changes: 3 additions & 19 deletions tests/kernels/moe/test_pplx_cutlass_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
import pytest
import torch

from tests.kernels.utils import torch_experts
from vllm import _custom_ops as ops
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp8
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.model_executor.layers.fused_moe.modular_kernel import (
Expand Down Expand Up @@ -164,22 +164,6 @@ def pplx_cutlass_moe(
vllm_config.scheduler_config.max_model_len = 8192


def torch_moe2(a, w1, w2, topk_weight, topk_ids):
M, K = a.shape
topk = topk_ids.shape[1]
a = a.view(M, -1, K).repeat(1, topk, 1).reshape(-1, K)
out = torch.zeros(M * topk, w2.shape[1], dtype=a.dtype, device=a.device)
num_experts = w1.shape[0]
for i in range(num_experts):
mask = (topk_ids == i).view(-1)
if mask.sum():
out[mask] = SiluAndMul()(
a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(0, 1)

return (out.view(M, -1, w2.shape[1]) *
topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1)


def _pplx_moe(
pgi: ProcessGroupInfo,
dp_size: int,
Expand Down Expand Up @@ -210,8 +194,8 @@ def _pplx_moe(
group_name = cpu_group.group_name

with set_current_vllm_config(vllm_config):
torch_output = torch_moe2(a_full, w1_full, w2_full, topk_weights,
topk_ids)
torch_output = torch_experts(a_full, w1_full, w2_full, topk_weights,
topk_ids)
pplx_output = pplx_cutlass_moe(pgi, dp_size, a, w1, w2, w1_scale,
w2_scale, topk_weights, topk_ids,
a1_scale, out_dtype, per_act_token,
Expand Down
Loading