[Kernels] MoE refactor #19636

Merged
merged 72 commits on Jul 2, 2025
Commits (72)
e8ab05a
turn try_get_optimal_moe_config into an op so it can be torch.compiled
bnellnm Jun 16, 2025
e60fc9e
lint
bnellnm Jun 16, 2025
515b60e
torch.compile tests
bnellnm Jun 16, 2025
b8c64a1
add tests
bnellnm Jun 16, 2025
f2916ac
add compiler + cudagraph tests
bnellnm Jun 16, 2025
9daa832
tests
bnellnm Jun 16, 2025
d269e47
reduce number of compile/cudagraph tests
bnellnm Jun 16, 2025
e4a4952
lint
bnellnm Jun 16, 2025
debd465
fix lint
bnellnm Jun 17, 2025
2681694
replace import that lint removed
bnellnm Jun 17, 2025
960f861
fixes
bnellnm Jun 17, 2025
7fef821
lint
bnellnm Jun 17, 2025
3c74170
opify at a higher level
bnellnm Jun 18, 2025
43441cd
de-opify deepgemm kernels
bnellnm Jun 18, 2025
813b66c
remove cruft
bnellnm Jun 18, 2025
010d904
MoE refactoring
bnellnm Jun 12, 2025
1b0fad3
make FusedMoEModularKernel a Leaf
bnellnm Jun 13, 2025
584de04
make FusedMoEModularKernel a Leaf
bnellnm Jun 13, 2025
c42f742
fix format
bnellnm Jun 13, 2025
8f91f36
config stuff + add more tests
bnellnm Jun 14, 2025
4f52150
fixes
bnellnm Jun 14, 2025
2c8ec1d
wip test
bnellnm Jun 16, 2025
0d39be3
fix mergea
bnellnm Jun 16, 2025
17097ea
disable buggy fp8 tests
bnellnm Jun 17, 2025
f5973ab
fixes
bnellnm Jun 17, 2025
c822322
more lint
bnellnm Jun 17, 2025
12b1df4
more lint
bnellnm Jun 17, 2025
c68fe52
merge
bnellnm Jun 18, 2025
af060d4
fix merge
bnellnm Jun 18, 2025
763f590
fix deep gemm test
bnellnm Jun 18, 2025
b9c027a
add supports_expert_map method + cleanup select_gemm_impl methods
bnellnm Jun 19, 2025
4407618
lint
bnellnm Jun 19, 2025
e9a66cb
revert random linter changes
bnellnm Jun 19, 2025
762394c
fix comments + lint
bnellnm Jun 20, 2025
e7973d7
remove some logging
bnellnm Jun 20, 2025
5fc344c
remove unused method
bnellnm Jun 20, 2025
72097bb
try to fix lint
bnellnm Jun 20, 2025
d1b83ba
add some asserts to make lint happy
bnellnm Jun 20, 2025
7422357
try again with the linter
bnellnm Jun 20, 2025
d1928ad
review comments + fixes
bnellnm Jun 25, 2025
7546a29
review comments + test fixes
bnellnm Jun 26, 2025
2061d68
fix test_mixtral_moe + bump up some tolerances
bnellnm Jun 26, 2025
96b08fc
remove duplicate test setup code. fix some tests, some still failing
bnellnm Jun 26, 2025
a6e7d47
lint
bnellnm Jun 26, 2025
149f7b7
more lint
bnellnm Jun 26, 2025
4b4ae50
fix lint
bnellnm Jun 26, 2025
07a2599
more linter fixes
bnellnm Jun 26, 2025
a26eab4
appease yapf/isort gods
bnellnm Jun 26, 2025
fd4ffd8
fix test_deepep_moe.py
bnellnm Jun 27, 2025
455a6ce
move deepep_utils -> parallel_utils
bnellnm Jun 27, 2025
3caa61f
fix test_block_fp8.py test
bnellnm Jun 27, 2025
bb5d8e9
more lint nonsense
bnellnm Jun 27, 2025
7684225
Fix incorrect per_act_token
ElizaWszola Jun 27, 2025
f188691
fix merge
bnellnm Jun 28, 2025
579af67
fix lint nonsense
bnellnm Jun 28, 2025
a76d2ef
fix merge
bnellnm Jun 28, 2025
550cc3b
fix merge
bnellnm Jun 28, 2025
d466524
fix deepep ht tests
bnellnm Jun 29, 2025
525affc
review comments, reduce test combinations, cleanup test code, etc.
bnellnm Jun 30, 2025
d2b6682
some quantization tweaks
bnellnm Jul 1, 2025
0972e75
fix weight config
bnellnm Jul 1, 2025
5b154fa
fix comment
bnellnm Jul 1, 2025
012af37
fix stupid bug
bnellnm Jul 1, 2025
9e17fb0
more fixes
bnellnm Jul 1, 2025
d81a46b
fix
bnellnm Jul 1, 2025
63837ad
fix lint
bnellnm Jul 1, 2025
8d8ed0a
fix LM Eval Small Models test failure
bnellnm Jul 1, 2025
9a9b8e9
shut lint up for now
bnellnm Jul 1, 2025
e635a37
bump up int8 tolerance a tiny bit
bnellnm Jul 1, 2025
db33d8f
fix merge
bnellnm Jul 2, 2025
347a7b7
fix messed up config setup
bnellnm Jul 2, 2025
86224d0
one more fix
bnellnm Jul 2, 2025
11 changes: 8 additions & 3 deletions benchmarks/kernels/benchmark_grouped_gemm_cutlass.py
@@ -113,6 +113,7 @@ def run_cutlass_moe(
w2_scale: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
per_act_token: bool,
num_repeats: int,
):
for _ in range(num_repeats):
@@ -124,7 +125,8 @@ def run_cutlass_moe(
topk_ids,
w1_scale,
w2_scale,
a1_scale=a_scale,
per_act_token,
a1_scale=None,
)

def run_cutlass_from_graph(
@@ -148,7 +150,8 @@ def run_cutlass_from_graph(
topk_ids,
w1_scale,
w2_scale,
a1_scale=a_scale,
per_act_token,
a1_scale=None,
)

def run_triton_from_graph(
@@ -227,6 +230,7 @@ def replay_graph(graph, num_repeats):
"w2_q": w2_q,
"w1_scale": w1_scale,
"w2_scale": w2_scale,
"per_act_token": per_act_token,
# cuda graph params
"cutlass_graph": cutlass_graph,
"triton_graph": triton_graph,
@@ -287,12 +291,13 @@ def replay_graph(graph, num_repeats):
w2_scale,
topk_weights,
topk_ids,
per_act_token,
num_warmup,
)

results.append(
benchmark.Timer(
stmt="run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids, num_runs)", # noqa: E501
stmt="run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids, per_act_token, num_runs)", # noqa: E501
globals=globals,
label=label,
sub_label=sub_label,
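The per_act_token flag is now threaded through the benchmark as a positional argument, so it also has to appear in the globals dict handed to benchmark.Timer: the stmt string is evaluated against that dict. A minimal, self-contained sketch of the stmt/globals wiring (not part of this PR; run_cutlass_moe_stub and the tensor sizes are hypothetical stand-ins for the real quantized MoE call):

# Sketch of the Timer/globals pattern used above: any name referenced in the
# stmt string, including the new per_act_token argument, must be present in
# the globals dict. run_cutlass_moe_stub is a stand-in, not the real kernel.
import torch
import torch.utils.benchmark as benchmark


def run_cutlass_moe_stub(a: torch.Tensor, per_act_token: bool, num_runs: int) -> None:
    # Placeholder work; only the calling convention matters for this sketch.
    for _ in range(num_runs):
        _ = a * (2.0 if per_act_token else 1.0)


bench_globals = {
    "run_cutlass_moe": run_cutlass_moe_stub,
    "a": torch.randn(16, 16),
    "per_act_token": True,  # newly threaded through, mirroring the diff above
    "num_runs": 10,
}

timer = benchmark.Timer(
    stmt="run_cutlass_moe(a, per_act_token, num_runs)",
    globals=bench_globals,
    label="grouped_gemm_cutlass",
    sub_label="sketch",
)
print(timer.timeit(5))

The real benchmark passes the quantized weights, scales, and top-k tensors instead; only the way per_act_token reaches the timed statement is being illustrated.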
190 changes: 190 additions & 0 deletions tests/kernels/moe/parallel_utils.py
@@ -0,0 +1,190 @@
# SPDX-License-Identifier: Apache-2.0
"""
DeepEP test utilities
"""
import dataclasses
import importlib
import os
import traceback
from typing import Callable, Optional

import torch
from torch.distributed import ProcessGroup
from torch.multiprocessing import (
spawn) # pyright: ignore[reportPrivateImportUsage]
from typing_extensions import Concatenate, ParamSpec

from vllm.utils import get_open_port

has_deep_ep = importlib.util.find_spec("deep_ep") is not None
if has_deep_ep:
from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import ( # noqa: E501
DeepEPHTPrepareAndFinalize)
from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import ( # noqa: E501
DeepEPLLPrepareAndFinalize)

## Parallel Processes Utils

P = ParamSpec("P")


@dataclasses.dataclass
class ProcessGroupInfo:
world_size: int
world_local_size: int
rank: int
node_rank: int
local_rank: int
device: torch.device


def _worker_parallel_launch(
local_rank: int,
world_size: int,
world_local_size: int,
node_rank: int,
init_method: str,
worker: Callable[Concatenate[ProcessGroupInfo, P], None],
*args: P.args,
**kwargs: P.kwargs,
) -> None:
rank = node_rank * world_local_size + local_rank
torch.cuda.set_device(local_rank)
device = torch.device("cuda", local_rank)
torch.distributed.init_process_group(
backend="cpu:gloo,cuda:nccl",
init_method=init_method,
rank=rank,
world_size=world_size,
device_id=device,
)
barrier = torch.tensor([rank], device=device)
torch.distributed.all_reduce(barrier)

try:
worker(
ProcessGroupInfo(
world_size=world_size,
world_local_size=world_local_size,
rank=rank,
node_rank=node_rank,
local_rank=local_rank,
device=device,
),
*args,
**kwargs,
)
except Exception as ex:
print(ex)
traceback.print_exc()
raise
finally:
torch.distributed.destroy_process_group()


def parallel_launch(
world_size: int,
worker: Callable[Concatenate[ProcessGroupInfo, P], None],
*args: P.args,
**kwargs: P.kwargs,
) -> None:
assert not kwargs
spawn(
_worker_parallel_launch,
args=(
world_size,
world_size,
0,
f"tcp://{os.getenv('LOCALHOST', 'localhost')}:{get_open_port()}",
worker,
) + args,
nprocs=world_size,
join=True,
)


## DeepEP specific utils


@dataclasses.dataclass
class DeepEPHTArgs:
num_local_experts: int


@dataclasses.dataclass
class DeepEPLLArgs:
max_tokens_per_rank: int
hidden_size: int
num_experts: int
use_fp8_dispatch: bool


def make_deepep_ht_a2a(pg: ProcessGroup,
pgi: ProcessGroupInfo,
dp_size: int,
ht_args: DeepEPHTArgs,
q_dtype: Optional[torch.dtype] = None,
block_shape: Optional[list[int]] = None):

import deep_ep

# high throughput a2a
num_nvl_bytes = 1024 * 1024 * 1024 # 1GB
num_rdma_bytes, low_latency_mode, num_qps_per_rank = 0, False, 1
buffer = deep_ep.Buffer(group=pg,
num_nvl_bytes=num_nvl_bytes,
num_rdma_bytes=num_rdma_bytes,
low_latency_mode=low_latency_mode,
num_qps_per_rank=num_qps_per_rank)
return DeepEPHTPrepareAndFinalize(buffer=buffer,
world_size=pgi.world_size,
rank=pgi.rank,
dp_size=dp_size,
rank_expert_offset=pgi.rank *
ht_args.num_local_experts)


def make_deepep_ll_a2a(pg: ProcessGroup,
pgi: ProcessGroupInfo,
dp_size: int,
deepep_ll_args: DeepEPLLArgs,
q_dtype: Optional[torch.dtype] = None,
block_shape: Optional[list[int]] = None):

import deep_ep

# low-latency a2a
num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(
deepep_ll_args.max_tokens_per_rank, deepep_ll_args.hidden_size,
pgi.world_size, deepep_ll_args.num_experts)

buffer = deep_ep.Buffer(group=pg,
num_rdma_bytes=num_rdma_bytes,
low_latency_mode=True,
num_qps_per_rank=deepep_ll_args.num_experts //
pgi.world_size)

return DeepEPLLPrepareAndFinalize(
buffer=buffer,
world_size=pgi.world_size,
dp_size=dp_size,
max_tokens_per_rank=deepep_ll_args.max_tokens_per_rank,
use_fp8_dispatch=deepep_ll_args.use_fp8_dispatch,
)


def make_deepep_a2a(pg: ProcessGroup,
pgi: ProcessGroupInfo,
dp_size: int,
deepep_ht_args: Optional[DeepEPHTArgs],
deepep_ll_args: Optional[DeepEPLLArgs],
q_dtype: Optional[torch.dtype] = None,
block_shape: Optional[list[int]] = None):
if deepep_ht_args is not None:
assert deepep_ll_args is None
return make_deepep_ht_a2a(pg, pgi, dp_size, deepep_ht_args, q_dtype,
block_shape)

assert deepep_ll_args is not None
return make_deepep_ll_a2a(pg, pgi, dp_size, deepep_ll_args, q_dtype,
block_shape)
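For reference, parallel_launch spawns one process per rank, initializes a cpu:gloo,cuda:nccl process group, and hands each worker its ProcessGroupInfo before any of the DeepEP helpers are touched. A minimal usage sketch (not part of this PR), assuming two visible GPUs and that the module is importable as tests.kernels.moe.parallel_utils:

# Hypothetical usage sketch: run a 2-GPU worker through parallel_launch.
# The import path assumes the repo root is on sys.path; adjust as needed.
import torch

from tests.kernels.moe.parallel_utils import ProcessGroupInfo, parallel_launch


def _rank_sum_worker(pgi: ProcessGroupInfo) -> None:
    # The process group is already initialized by _worker_parallel_launch,
    # so collectives can be issued directly on the per-rank device.
    x = torch.tensor([float(pgi.rank)], device=pgi.device)
    torch.distributed.all_reduce(x)
    print(f"rank {pgi.rank}/{pgi.world_size}: sum of ranks = {x.item()}")


if __name__ == "__main__":
    parallel_launch(2, _rank_sum_worker)

The DeepEP-specific helpers (make_deepep_ht_a2a / make_deepep_ll_a2a) additionally require deep_ep to be installed; make_deepep_a2a selects between them based on whether DeepEPHTArgs or DeepEPLLArgs is provided.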