Commit d1b9b99

MoE refactoring
Signed-off-by: Bill Nell <bnell@redhat.com>
1 parent 3597b06 commit d1b9b99

15 files changed: +208 -76 lines changed

vllm/model_executor/layers/fused_moe/__init__.py

Lines changed: 26 additions & 2 deletions
@@ -5,7 +5,11 @@
 from typing import Any, Optional
 
 from vllm.model_executor.layers.fused_moe.layer import (
-    FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
+    FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported, MoEConfig)
+from vllm.model_executor.layers.fused_moe.modular_kernel import (
+    FusedMoEPrepareAndFinalize,
+    FusedMoEPermuteExpertsUnpermute,
+    FusedMoEActivationFormat)
 from vllm.triton_utils import HAS_TRITON
 
 _config: Optional[dict[str, Any]] = None
@@ -28,6 +32,10 @@ def get_config() -> Optional[dict[str, Any]]:
     "FusedMoE",
     "FusedMoEMethodBase",
     "FusedMoeWeightScaleSupported",
+    "FusedMoEPermuteExpertsUnpermute",
+    "FusedMoEActivationFormat",
+    "FusedMoEPrepareAndFinalize",
+    "MoEConfig",
     "override_config",
     "get_config",
 ]
@@ -37,10 +45,20 @@ def get_config() -> Optional[dict[str, Any]]:
     import vllm.model_executor.layers.fused_moe.fused_marlin_moe  # noqa
     import vllm.model_executor.layers.fused_moe.fused_moe  # noqa
     from vllm.model_executor.layers.fused_moe.cutlass_moe import (
-        cutlass_moe_fp4, cutlass_moe_fp8)
+        cutlass_moe_fp4, cutlass_moe_fp8, CutlassExpertsFp8)
     from vllm.model_executor.layers.fused_moe.fused_moe import (
         TritonExperts, fused_experts, fused_moe, fused_topk,
         get_config_file_name, grouped_topk)
+    from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
+        BatchedTritonExperts)
+    from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
+        DeepGemmExperts)
+    from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
+        BatchedDeepGemmExperts)
+    from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
+        TritonOrDeepGemmExperts)
+    from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import (
+        BatchedTritonOrDeepGemmExperts)
 
     __all__ += [
         "fused_moe",
@@ -50,5 +68,11 @@ def get_config() -> Optional[dict[str, Any]]:
         "grouped_topk",
         "cutlass_moe_fp8",
         "cutlass_moe_fp4",
+        "CutlassExpertsFp8",
         "TritonExperts",
+        "BatchedTritonExperts",
+        "DeepGemmExperts",
+        "BatchedDeepGemmExperts",
+        "TritonOrDeepGemmExperts",
+        "BatchedTritonOrDeepGemmExperts",
     ]
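
With the widened export list above, the new expert implementations and the MoEConfig / activation-format types are reachable from the package root. A minimal sketch, assuming a CUDA build of vLLM where HAS_TRITON is true so the optional exports are populated:

    # Sketch only; assumes HAS_TRITON so the Triton-backed exports exist.
    from vllm.model_executor.layers.fused_moe import (
        FusedMoEActivationFormat, TritonExperts)

    experts = TritonExperts(use_fp8_w8a8=False,
                            use_int8_w8a8=False,
                            use_int8_w8a16=False,
                            use_int4_w4a16=False,
                            block_shape=None,
                            per_channel_quant=False)

    # Per the property added to fused_moe.py below, unquantized Triton experts
    # consume and produce the Standard activation format.
    assert experts.activation_formats == (FusedMoEActivationFormat.Standard,
                                          FusedMoEActivationFormat.Standard)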

vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py

Lines changed: 9 additions & 0 deletions
@@ -67,6 +67,15 @@ def __init__(self,
         assert (self.batched_deep_gemm_experts is not None
                 or self.batched_triton_experts is not None)
 
+    @property
+    def activation_formats(self) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
+        if self.batched_triton_experts is not None:
+            assert self.batched_deep_gemm_experts is None or self.batched_deep_gemm_experts.activation_formats == self.batched_triton_experts.activation_formats
+            return self.batched_triton_experts.activation_formats
+        else:
+            assert self.batched_deep_gemm_experts is not None
+            return self.batched_deep_gemm_experts.activation_formats
+
     def supports_chunking(self) -> bool:
         bdge = self.batched_deep_gemm_experts
         bte = self.batched_triton_experts
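
The property above delegates to whichever backend is configured and asserts that both agree when both are present. A hedged sketch of how a caller could use these properties to pair a prepare/finalize object with an experts implementation; the helper name below is invented for illustration and is not part of this commit:

    # Illustrative helper, not vLLM code: checks that the layout produced by a
    # prepare/finalize object matches the layout the experts expect to receive.
    def check_prepare_experts_compat(prepare_finalize, experts) -> None:
        expected_input_format, _output_format = experts.activation_formats
        assert prepare_finalize.activation_format == expected_input_format, (
            "prepare/finalize output layout does not match the experts' input")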

vllm/model_executor/layers/fused_moe/cutlass_moe.py

Lines changed: 5 additions & 0 deletions
@@ -219,6 +219,11 @@ def __init__(
         self.per_out_ch = per_out_ch
         self.use_batched_format = use_batched_format
 
+    @property
+    def activation_formats(self) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
+        return (mk.FusedMoEActivationFormat.Standard,
+                mk.FusedMoEActivationFormat.Standard)
+
     def supports_chunking(self) -> bool:
         return not self.use_batched_format
 

vllm/model_executor/layers/fused_moe/deep_gemm_moe.py

Lines changed: 5 additions & 0 deletions
@@ -70,6 +70,11 @@ def __init__(self):
         super().__init__()
         self.block_shape = deep_gemm_block_shape()
 
+    @property
+    def activation_formats(self) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
+        return (mk.FusedMoEActivationFormat.Standard,
+                mk.FusedMoEActivationFormat.Standard)
+
     def supports_chunking(self) -> bool:
         return True
 
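
Every expert implementation touched by this commit reports its formats with the same two enum members. The enum itself lives in modular_kernel.py, which is not part of this diff; a rough approximation of its shape, sufficient for reading the properties above (the member names come from this diff, the values and comments are assumptions):

    # Approximation only; the real definition is in
    # vllm/model_executor/layers/fused_moe/modular_kernel.py.
    import enum

    class FusedMoEActivationFormat(enum.Enum):
        # Activations kept as one contiguous [num_tokens, hidden] tensor.
        Standard = "standard"
        # Activations grouped per expert, as used by the batched kernels.
        BatchedExperts = "batched_experts"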

vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py

Lines changed: 12 additions & 8 deletions
@@ -39,6 +39,10 @@ def __init__(self,
         # From https://github.com/deepseek-ai/DeepEP/blob/9fe9021f29c9083cd1808ab36b740208524d9f63/deep_ep/buffer.py#L164
         self.available_rank_configs = [2, 4, 8, 16, 24, 32, 64, 128, 144, 160]
 
+    @property
+    def activation_format(self) -> mk.FusedMoEActivationFormat:
+        return mk.FusedMoEActivationFormat.Standard
+
     def max_num_tokens_per_rank(self) -> Optional[int]:
         return None
 
@@ -130,20 +134,20 @@ def prepare(
         a1: torch.Tensor,
         a1_scale: Optional[torch.Tensor],
         a2_scale: Optional[torch.Tensor],
-        rank_topk_weights: torch.Tensor,
-        rank_topk_ids: torch.Tensor,
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
         num_experts: int,
         expert_map: Optional[torch.Tensor],
         apply_router_weight_on_input: bool,
     ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor],
               Optional[torch.Tensor], Optional[torch.Tensor]]:
 
         if apply_router_weight_on_input:
-            topk = rank_topk_ids.size(1)
+            topk = topk_ids.size(1)
             # TODO: this only works for topK=1, will need to update for topK>1
             assert topk == 1, (
                 "apply_router_weight_on_input is only implemented for topk=1")
-            a1 = a1 * rank_topk_weights.to(a1.dtype)
+            a1 = a1 * topk_weights.to(a1.dtype)
 
         # Check if there is a block_shape / or if we can infer the quantization
         # schemes from the scales.
@@ -165,8 +169,8 @@ def prepare(
              expert_topk_weights) = self._do_dispatch(
                  tokens=a1q,
                  token_scales=a1q_scale,
-                 rank_topk_ids=rank_topk_ids,
-                 rank_topk_weights=rank_topk_weights,
+                 rank_topk_ids=topk_ids,
+                 rank_topk_weights=topk_weights,
                  num_experts=num_experts)
         else:
             # DeepEP kernels only support dispatching per-token-quant
@@ -175,8 +179,8 @@ def prepare(
              expert_topk_weights) = self._do_dispatch(
                  tokens=a1,
                  token_scales=None,
-                 rank_topk_ids=rank_topk_ids,
-                 rank_topk_weights=rank_topk_weights,
+                 rank_topk_ids=topk_ids,
+                 rank_topk_weights=topk_weights,
                  num_experts=num_experts)
         # quantize now
         expert_x_scale = None
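
The rename from rank_topk_weights/rank_topk_ids to topk_weights/topk_ids does not change the apply_router_weight_on_input path: when the router weight is folded into the activations, only topk == 1 is supported and the scaling happens once before dispatch. A small self-contained illustration of that pre-scaling step, with toy shapes:

    # Illustration of the pre-scaling branch above; tensor sizes are made up.
    import torch

    a1 = torch.randn(4, 8)                 # [num_tokens, hidden_size]
    topk_weights = torch.rand(4, 1)        # [num_tokens, topk] with topk == 1
    topk_ids = torch.zeros(4, 1, dtype=torch.int64)

    topk = topk_ids.size(1)
    assert topk == 1, "apply_router_weight_on_input only supports topk=1"
    a1 = a1 * topk_weights.to(a1.dtype)    # router weight applied up front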

vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py

Lines changed: 9 additions & 5 deletions
@@ -59,6 +59,10 @@ def __init__(self,
         # combine function.
         self.handle = None
 
+    @property
+    def activation_format(self) -> mk.FusedMoEActivationFormat:
+        return mk.FusedMoEActivationFormat.BatchedExperts
+
     def max_num_tokens_per_rank(self) -> Optional[int]:
         return self.max_tokens_per_rank
 
@@ -118,8 +122,8 @@ def prepare(
         a1: torch.Tensor,
         a1_scale: Optional[torch.Tensor],
         a2_scale: Optional[torch.Tensor],
-        rank_topk_weights: torch.Tensor,
-        rank_topk_ids: torch.Tensor,
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
         num_experts: int,
         expert_map: Optional[torch.Tensor],
         apply_router_weight_on_input: bool,
@@ -142,16 +146,16 @@ def prepare(
             "low_latency kernels doesn't support dispatching per-token scales")
 
         if apply_router_weight_on_input:
-            topk = rank_topk_ids.size(1)
+            topk = topk_ids.size(1)
             # TODO: this only works for topK=1, will need to update for topK>1
             assert topk == 1, (
                 "apply_router_weight_on_input is only implemented for topk=1")
-            a1 = a1 * rank_topk_weights.to(a1.dtype)
+            a1 = a1 * topk_weights.to(a1.dtype)
 
         # Dispatch
         expert_x, expert_num_tokens, self.handle, event, hook = \
             self.buffer.low_latency_dispatch(a1,
-                                             rank_topk_ids,
+                                             topk_ids,
                                              self.max_tokens_per_rank,
                                              num_experts,
                                              use_fp8=self.use_fp8_dispatch,

vllm/model_executor/layers/fused_moe/fused_batched_moe.py

Lines changed: 14 additions & 0 deletions
@@ -395,6 +395,10 @@ def __init__(self, max_num_tokens: int, world_size: int, dp_size: int,
         self.rank = rank
         self.max_num_tokens = max_num_tokens
 
+    @property
+    def activation_format(self) -> mk.FusedMoEActivationFormat:
+        return mk.FusedMoEActivationFormat.BatchedExperts
+
     def max_num_tokens_per_rank(self) -> Optional[int]:
         return self.max_num_tokens
 
@@ -510,6 +514,11 @@ def __init__(
         self.world_size = world_size
         self.dp_size = dp_size
 
+    @property
+    def activation_formats(self) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
+        return (mk.FusedMoEActivationFormat.BatchedExperts,
+                mk.FusedMoEActivationFormat.BatchedExperts)
+
     def supports_chunking(self) -> bool:
         return False
 
@@ -615,6 +624,11 @@ def __init__(
         assert not use_int4_w4a16, "NYI"
         assert self.block_shape is None, "NYI"
 
+    @property
+    def activation_formats(self) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
+        return (mk.FusedMoEActivationFormat.BatchedExperts,
+                mk.FusedMoEActivationFormat.BatchedExperts)
+
     def supports_chunking(self) -> bool:
         return False
 

vllm/model_executor/layers/fused_moe/fused_moe.py

Lines changed: 5 additions & 0 deletions
@@ -1542,6 +1542,11 @@ def __init__(
             use_int4_w4a16=use_int4_w4a16)
         self.per_channel_quant = per_channel_quant
 
+    @property
+    def activation_formats(self) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
+        return (mk.FusedMoEActivationFormat.Standard,
+                mk.FusedMoEActivationFormat.Standard)
+
     def supports_chunking(self) -> bool:
         return True
 

vllm/model_executor/layers/fused_moe/layer.py

Lines changed: 18 additions & 18 deletions
@@ -23,6 +23,9 @@
 from vllm.forward_context import ForwardContext, get_forward_context
 from vllm.logger import init_logger
 from vllm.model_executor.custom_op import CustomOp
+from .modular_kernel import (FusedMoEModularKernel,
+                             FusedMoEPermuteExpertsUnpermute,
+                             FusedMoEPrepareAndFinalize)
 from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
     is_rocm_aiter_moe_enabled)
 from vllm.model_executor.layers.quantization.base_config import (
@@ -38,9 +41,6 @@
 if current_platform.is_cuda_alike():
     from .fused_batched_moe import BatchedTritonExperts
     from .fused_moe import TritonExperts, fused_experts
-    from .modular_kernel import (FusedMoEModularKernel,
-                                 FusedMoEPermuteExpertsUnpermute,
-                                 FusedMoEPrepareAndFinalize)
     if has_pplx:
         from .pplx_prepare_finalize import PplxPrepareAndFinalize
     if has_deepep:
@@ -304,9 +304,8 @@ def init_prepare_finalize(self, moe: MoEConfig,
             act_quant_block_size = quant_config.weight_block_size
             quant_dtype = torch.float8_e4m3fn
 
-        prepare_finalize: Optional[Union[PplxPrepareAndFinalize,
-                                         DeepEPHTPrepareAndFinalize,
-                                         DeepEPLLPrepareAndFinalize]] = None
+        prepare_finalize: Optional[FusedMoEPrepareAndFinalize] = None
+
         if moe.use_pplx_kernels:
             all_to_all_args = dict(
                 max_num_tokens=moe.max_num_tokens,
@@ -399,8 +398,10 @@ def init_prepare_finalize(self, moe: MoEConfig,
             )
 
     def select_gemm_impl(
-            self, prepare_finalize: FusedMoEPrepareAndFinalize,
-            moe: Optional[MoEConfig]) -> FusedMoEPermuteExpertsUnpermute:
+        self,
+        prepare_finalize: FusedMoEPrepareAndFinalize,
+        moe: MoEConfig
+    ) -> FusedMoEPermuteExpertsUnpermute:
         # based on the all2all implementation, select the appropriate
         # gemm implementation
         raise NotImplementedError(
@@ -446,23 +447,23 @@ def __init__(self, moe: MoEConfig):
         else:
             self.rocm_aiter_fused_experts = None  # type: ignore
 
-    def select_gemm_impl(self, prepare_finalize: FusedMoEPrepareAndFinalize,
-                         moe: Optional[MoEConfig]):
+    def select_gemm_impl(
+        self,
+        prepare_finalize: FusedMoEPrepareAndFinalize,
+        moe: MoEConfig
+    ) -> FusedMoEPermuteExpertsUnpermute:
 
         assert self.fused_experts == fused_experts
 
         all2all_manager = get_ep_group().device_communicator.all2all_manager
         assert all2all_manager is not None
 
-        experts: Optional[FusedMoEPermuteExpertsUnpermute] = None
-
-        use_batched_experts = prepare_finalize.max_num_tokens_per_rank(
-        ) is not None
-        if use_batched_experts:
+        if prepare_finalize.activation_format == FusedMoEActivationFormat.BatchedExperts:
             logger.debug("BatchedTritonExperts %s", self.moe)
             assert self.moe.dp_size == all2all_manager.dp_world_size
-            experts = BatchedTritonExperts(
+            return BatchedTritonExperts(
                 max_num_tokens=self.moe.max_num_tokens,
+                # TODO (bnell): Fix this mess
                 world_size=all2all_manager.world_size,
                 # dp_size actually means tp_size, bug in pplx kernels
                 dp_size=all2all_manager.tp_group.world_size,
@@ -475,15 +476,14 @@ def select_gemm_impl(self, prepare_finalize: FusedMoEPrepareAndFinalize,
             )
         else:
             logger.debug("TritonExperts %s", self.moe)
-            experts = TritonExperts(
+            return TritonExperts(
                 use_fp8_w8a8=False,
                 use_int8_w8a8=False,
                 use_int8_w8a16=False,
                 use_int4_w4a16=False,
                 block_shape=None,
                 per_channel_quant=False,
             )
-        return experts
 
     def create_weights(self, layer: torch.nn.Module, num_experts: int,
                        hidden_size: int, intermediate_size_per_partition: int,
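
The dispatch above replaces the old max_num_tokens_per_rank() probe with an explicit check of the prepare/finalize object's activation format. A self-contained sketch of the new criterion; DummyPrepareFinalize is a stand-in for illustration, not a vLLM class:

    # Sketch of the new selection criterion in select_gemm_impl.
    from vllm.model_executor.layers.fused_moe import FusedMoEActivationFormat

    class DummyPrepareFinalize:
        # Stand-in exposing only the property the dispatch relies on.
        activation_format = FusedMoEActivationFormat.BatchedExperts

    def wants_batched_experts(prepare_finalize) -> bool:
        # Previously: prepare_finalize.max_num_tokens_per_rank() is not None.
        # Now: ask the prepare/finalize object for its activation layout.
        return (prepare_finalize.activation_format ==
                FusedMoEActivationFormat.BatchedExperts)

    print(wants_batched_experts(DummyPrepareFinalize()))  # True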
