
Commit f0c98ca

Authored by Varun Sundar Rabindranath (varun-sundar-rabindranath)
[Misc] MoE ModularKernel : Introduce TopKWeightAndReduce (#20648)
Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
1 parent 574ad60 commit f0c98ca

14 files changed: +297 −59 lines changed

tests/kernels/moe/test_pplx_moe.py

Lines changed: 3 additions & 0 deletions
@@ -32,6 +32,8 @@
 from vllm.model_executor.layers.fused_moe.fused_moe import get_default_config
 from vllm.model_executor.layers.fused_moe.modular_kernel import (
     FusedMoEModularKernel)
+from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
+    TopKWeightAndReduceDelegate)
 from vllm.platforms import current_platform
 from vllm.utils import round_up

@@ -371,6 +373,7 @@ def pplx_prepare_finalize(
         chunk_topk_weight,
         chunk_topk_ids,
         False,
+        weight_and_reduce_impl=TopKWeightAndReduceDelegate(),
     )

     torch.cuda.synchronize()

vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py

Lines changed: 6 additions & 0 deletions
@@ -7,6 +7,8 @@
 import vllm.model_executor.layers.fused_moe.modular_kernel as mk
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
+from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
+    TopKWeightAndReduceDelegate)
 from vllm.model_executor.layers.fused_moe.utils import _resize_cache
 from vllm.triton_utils import tl, triton

@@ -217,6 +219,10 @@ def supports_chunking(self) -> bool:
     def supports_expert_map(self) -> bool:
         return False

+    def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
+        # Let PrepareAndFinalize::finalize() decide the impl.
+        return TopKWeightAndReduceDelegate()
+
     def workspace_shapes(
         self,
         a: torch.Tensor,
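
Note: the finalize_weight_and_reduce_impl() hook added above returns a TopKWeightAndReduceDelegate, signalling that this experts implementation does not apply topk weights or reduce on its own and leaves that choice to PrepareAndFinalize::finalize(). For contrast, below is a purely hypothetical sketch of what an experts implementation that already fuses the weight application and reduction into its kernel could return instead; the class name and behaviour are illustrative and are not part of this commit.

import torch

import vllm.model_executor.layers.fused_moe.modular_kernel as mk


class SketchTopKWeightAndReduceNoOp(mk.TopKWeightAndReduce):
    """Illustrative only: for experts kernels that already weight and reduce."""

    def apply(self, output, fused_expert_output, topk_weights, topk_ids,
              apply_router_weight_on_input):
        # The experts kernel already produced the final
        # (num_tokens, hidden_dim) tensor, so nothing is left to do beyond
        # honouring a caller-provided output buffer.
        if output is None:
            return fused_expert_output
        output.copy_(fused_expert_output)
        return output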

vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py

Lines changed: 19 additions & 0 deletions
@@ -88,6 +88,25 @@ def supports_expert_map(self) -> bool:
         return ((bdge is None or bdge.supports_expert_map())
                 and (bte is None or bte.supports_expert_map()))

+    def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
+        bdge = self.batched_deep_gemm_experts
+        bte = self.batched_triton_experts
+        bdge_war = bdge.finalize_weight_and_reduce_impl() if bdge else None
+        bte_war = bte.finalize_weight_and_reduce_impl() if bte else None
+        is_bdge_war = bdge_war is not None
+        is_bte_war = bte_war is not None
+
+        if is_bdge_war and is_bte_war:
+            assert bdge_war == bte_war, (
+                "Both implementations should agree on WeightAndReduce impls. "
+                f"Got bdge_war: {bdge_war}, and bte_war: {bte_war}")
+
+        if bdge_war is not None:
+            return bdge_war
+
+        assert bte_war is not None
+        return bte_war
+
     def workspace_shapes(
         self,
         a: torch.Tensor,
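
Note: the agreement assert above compares the two candidate implementations with ==, and most experts classes in this commit return TopKWeightAndReduceDelegate(). The real delegate lives in vllm/model_executor/layers/fused_moe/topk_weight_and_reduce.py, which is not part of this excerpt; the sketch below is a plausible reconstruction, and the __eq__ detail is an assumption made so that two freshly constructed delegates compare equal.

import vllm.model_executor.layers.fused_moe.modular_kernel as mk


class SketchTopKWeightAndReduceDelegate(mk.TopKWeightAndReduce):
    """Marker impl: 'let PrepareAndFinalize::finalize() pick the strategy'."""

    def __eq__(self, other):
        # Assumption: equality by type, so the agreement assert above holds
        # even when the two sub-experts return distinct instances.
        return isinstance(other, SketchTopKWeightAndReduceDelegate)

    def apply(self, output, fused_expert_output, topk_weights, topk_ids,
              apply_router_weight_on_input):
        # Never invoked directly; finalize() replaces the delegate with a
        # concrete implementation before calling apply().
        raise NotImplementedError(
            "Delegate should be swapped out by finalize()")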

vllm/model_executor/layers/fused_moe/cutlass_moe.py

Lines changed: 6 additions & 0 deletions
@@ -11,6 +11,8 @@
 from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
 from vllm.model_executor.layers.fused_moe.prepare_finalize import (
     MoEPrepareAndFinalizeNoEP)
+from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
+    TopKWeightAndReduceDelegate)
 from vllm.model_executor.layers.fused_moe.utils import (_fp8_perm,
                                                         _fp8_quantize,
                                                         _resize_cache)
@@ -255,6 +257,10 @@ def supports_chunking(self) -> bool:
     def supports_expert_map(self) -> bool:
         return not self.use_batched_format

+    def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
+        # Let PrepareAndFinalize::finalize() decide the impl.
+        return TopKWeightAndReduceDelegate()
+
     def workspace_shapes(
         self,
         a: torch.Tensor,

vllm/model_executor/layers/fused_moe/deep_gemm_moe.py

Lines changed: 6 additions & 0 deletions
@@ -12,6 +12,8 @@
     _moe_permute)
 from vllm.model_executor.layers.fused_moe.prepare_finalize import (
     MoEPrepareAndFinalizeNoEP)
+from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
+    TopKWeightAndReduceDelegate)
 from vllm.model_executor.layers.fused_moe.utils import (
     _resize_cache, per_token_group_quant_fp8)
 from vllm.utils import has_deep_gemm, round_up
@@ -85,6 +87,10 @@ def supports_chunking(self) -> bool:
     def supports_expert_map(self) -> bool:
         return True

+    def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
+        # Let PrepareAndFinalize::finalize() decide the impl.
+        return TopKWeightAndReduceDelegate()
+
     def workspace_shapes(
         self, a: torch.Tensor, aq: torch.Tensor, M: int, N: int, K: int,
         topk: int, global_num_experts: int, local_num_experts: int

vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py

Lines changed: 10 additions & 29 deletions
@@ -6,8 +6,9 @@
 import torch

 import vllm.model_executor.layers.fused_moe.modular_kernel as mk
-from vllm import _custom_ops as ops
 from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
+from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
+    TopKWeightAndReduceContiguous, TopKWeightAndReduceDelegate)
 from vllm.model_executor.layers.fused_moe.utils import (
     moe_kernel_quantize_input)

@@ -187,45 +188,25 @@ def prepare(
         return (expert_x, expert_x_scale, expert_tokens_meta, expert_topk_ids,
                 expert_topk_weights)

-    def _apply_weights_and_reduce(self, num_tokens: int,
-                                  fused_expert_output: torch.Tensor,
-                                  topk_weights: torch.Tensor,
-                                  apply_router_weight_on_input: bool,
-                                  output_dtype: torch.dtype):
-
-        hidden_dim = fused_expert_output.size(-1)
-        if fused_expert_output.ndim == 2:
-            fused_expert_output = fused_expert_output.view(
-                num_tokens, -1, hidden_dim)
-
-        if not apply_router_weight_on_input:
-            # The DeepEP combine kernels don't do the topk weight
-            # multiplication. We multiply the weights locally.
-            m_x_topk = fused_expert_output.size(0)
-            fused_expert_output.mul_(topk_weights.view(m_x_topk, -1, 1))
-
-        out = torch.empty((num_tokens, hidden_dim),
-                          device=fused_expert_output.device,
-                          dtype=output_dtype)
-        ops.moe_sum(fused_expert_output, out)
-
-        return out
-
     def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor,
                  topk_weights: torch.Tensor, topk_ids: torch.Tensor,
-                 apply_router_weight_on_input: bool) -> None:
+                 apply_router_weight_on_input: bool,
+                 weight_and_reduce_impl: mk.TopKWeightAndReduce) -> None:

         assert self.handle is not None

         # fused_expert_output can have 0 tokens - This happens when none of the
         # tokens from the all2all reach this EP rank.
         if fused_expert_output.numel() != 0:
-            fused_expert_output = self._apply_weights_and_reduce(
-                num_tokens=topk_ids.size(0),
+            if isinstance(weight_and_reduce_impl, TopKWeightAndReduceDelegate):
+                weight_and_reduce_impl = TopKWeightAndReduceContiguous()
+            fused_expert_output = weight_and_reduce_impl.apply(
+                output=None,
                 fused_expert_output=fused_expert_output,
                 topk_weights=topk_weights,
+                topk_ids=topk_ids,
                 apply_router_weight_on_input=apply_router_weight_on_input,
-                output_dtype=output.dtype)
+            )

         combined_x, _, event = self.buffer.combine(
             x=fused_expert_output,
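
Note: the deleted _apply_weights_and_reduce helper above is replaced by TopKWeightAndReduceContiguous, whose definition is not part of this excerpt. Below is a minimal sketch reconstructed from the removed helper, assuming the new class keeps essentially the same math (a local topk-weight multiply followed by a sum over the topk dimension); ops.moe_sum is replaced by torch.sum purely to keep the sketch self-contained.

import torch

import vllm.model_executor.layers.fused_moe.modular_kernel as mk


class SketchTopKWeightAndReduceContiguous(mk.TopKWeightAndReduce):

    def apply(self, output, fused_expert_output, topk_weights, topk_ids,
              apply_router_weight_on_input):
        num_tokens = topk_ids.size(0)
        hidden_dim = fused_expert_output.size(-1)
        # View the rows as (num_tokens, topk, hidden_dim) before reducing.
        fused_expert_output = fused_expert_output.view(
            num_tokens, -1, hidden_dim)

        if not apply_router_weight_on_input:
            # The DeepEP combine kernels don't do the topk weight
            # multiplication, so apply the weights locally here.
            fused_expert_output = fused_expert_output * topk_weights.view(
                num_tokens, -1, 1)

        if output is None:
            output = torch.empty((num_tokens, hidden_dim),
                                 device=fused_expert_output.device,
                                 dtype=fused_expert_output.dtype)
        # Reduce over the topk dimension into the output tensor.
        torch.sum(fused_expert_output, dim=1, out=output)
        return output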

vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py

Lines changed: 7 additions & 2 deletions
@@ -7,6 +7,8 @@

 import vllm.model_executor.layers.fused_moe.modular_kernel as mk
 from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
+from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
+    TopKWeightAndReduceDelegate)
 from vllm.model_executor.layers.fused_moe.utils import (
     moe_kernel_quantize_input, normalize_batched_scales_shape)

@@ -166,8 +168,11 @@ def prepare(

     def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor,
                  topk_weights: torch.Tensor, topk_ids: torch.Tensor,
-                 apply_router_weight_on_input: bool) -> None:
-
+                 apply_router_weight_on_input: bool,
+                 weight_and_reduce_impl: mk.TopKWeightAndReduce) -> None:
+        assert isinstance(
+            weight_and_reduce_impl, TopKWeightAndReduceDelegate
+        ), ("Weight application and reduction happens in the combine kernel.")
         assert self.handle is not None

         combine_topk_weights = topk_weights

vllm/model_executor/layers/fused_moe/fused_batched_moe.py

Lines changed: 20 additions & 18 deletions
@@ -11,6 +11,8 @@
 from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
 from vllm.model_executor.layers.fused_moe.fused_moe import (
     get_config_dtype_str, try_get_optimal_moe_config)
+from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
+    TopKWeightAndReduceDelegate, TopKWeightAndReduceNaiveBatched)
 from vllm.model_executor.layers.fused_moe.utils import (
     _resize_cache, moe_kernel_quantize_input, normalize_batched_scales_shape,
     normalize_scales_shape)
@@ -600,25 +602,17 @@ def finalize(
         topk_weights: torch.Tensor,
         topk_ids: torch.Tensor,
         apply_router_weight_on_input: bool,
+        weight_and_reduce_impl: mk.TopKWeightAndReduce,
     ) -> None:
-        num_tokens = topk_ids.size(0)
-        num_local_experts = fused_expert_output.size(0)
-        K = fused_expert_output.size(-1)
-        assert output.size(0) == num_tokens and output.size(1) == K
-
-        output.fill_(0)
-
-        first_expert = num_local_experts * self.rank
-        last_expert = first_expert + num_local_experts
-
-        for expert_id in range(first_expert, last_expert):
-            matching_tokens = topk_ids == expert_id
-            topks = torch.any(matching_tokens, dim=1).flatten()
-            rows = torch.count_nonzero(topks)
-            rhs = fused_expert_output[expert_id - first_expert, :rows, :]
-            if not apply_router_weight_on_input:
-                rhs.mul_(topk_weights[matching_tokens].view(rhs.size(0), 1))
-            output[topks] = output[topks] + rhs
+        if isinstance(weight_and_reduce_impl, TopKWeightAndReduceDelegate):
+            weight_and_reduce_impl = TopKWeightAndReduceNaiveBatched(self.rank)
+        weight_and_reduce_impl.apply(
+            output=output,
+            fused_expert_output=fused_expert_output,
+            topk_weights=topk_weights,
+            topk_ids=topk_ids,
+            apply_router_weight_on_input=apply_router_weight_on_input,
+        )


 class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
@@ -670,6 +664,10 @@ def supports_chunking(self) -> bool:
     def supports_expert_map(self) -> bool:
         return False

+    def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
+        # Let PrepareAndFinalize::finalize() decide the impl.
+        return TopKWeightAndReduceDelegate()
+
     def workspace_shapes(
         self,
         a: torch.Tensor,
@@ -877,6 +875,10 @@ def supports_chunking(self) -> bool:
     def supports_expert_map(self) -> bool:
         return False

+    def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
+        # Let PrepareAndFinalize::finalize() decide the impl.
+        return TopKWeightAndReduceDelegate()
+
     def workspace_shapes(
         self,
         a: torch.Tensor,
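
Note: the per-expert accumulation loop removed from finalize() above is, per this commit, relocated behind TopKWeightAndReduceNaiveBatched (defined in topk_weight_and_reduce.py, not shown in this excerpt). Below is a sketch reconstructed from the deleted code, assuming the class simply wraps the same loop; the in-place mul_ on the experts output is made out-of-place here to keep the sketch side-effect free.

import torch

import vllm.model_executor.layers.fused_moe.modular_kernel as mk


class SketchTopKWeightAndReduceNaiveBatched(mk.TopKWeightAndReduce):

    def __init__(self, rank: int):
        self.rank = rank

    def apply(self, output, fused_expert_output, topk_weights, topk_ids,
              apply_router_weight_on_input):
        num_tokens = topk_ids.size(0)
        num_local_experts = fused_expert_output.size(0)
        K = fused_expert_output.size(-1)
        assert output.size(0) == num_tokens and output.size(1) == K

        output.fill_(0)

        first_expert = num_local_experts * self.rank
        last_expert = first_expert + num_local_experts

        # Scatter-add each local expert's rows back to the tokens that
        # selected it, applying the router weight unless it was already
        # applied on the input.
        for expert_id in range(first_expert, last_expert):
            matching_tokens = topk_ids == expert_id
            topks = torch.any(matching_tokens, dim=1).flatten()
            rows = torch.count_nonzero(topks)
            rhs = fused_expert_output[expert_id - first_expert, :rows, :]
            if not apply_router_weight_on_input:
                rhs = rhs * topk_weights[matching_tokens].view(rhs.size(0), 1)
            output[topks] = output[topks] + rhs
        return output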

vllm/model_executor/layers/fused_moe/fused_moe.py

Lines changed: 6 additions & 0 deletions
@@ -25,6 +25,8 @@
     moe_align_block_size)
 from vllm.model_executor.layers.fused_moe.prepare_finalize import (
     MoEPrepareAndFinalizeNoEP)
+from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
+    TopKWeightAndReduceDelegate)
 from vllm.model_executor.layers.fused_moe.utils import (
     _resize_cache, moe_kernel_quantize_input)
 from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
@@ -1596,6 +1598,10 @@ def supports_chunking(self) -> bool:
     def supports_expert_map(self) -> bool:
         return True

+    def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
+        # Let PrepareAndFinalize::finalize() decide the impl.
+        return TopKWeightAndReduceDelegate()
+
     def workspace_shapes(
         self,
         a: torch.Tensor,

vllm/model_executor/layers/fused_moe/modular_kernel.py

Lines changed: 38 additions & 6 deletions
@@ -23,7 +23,7 @@
 #
 # [Router] → [Quantize-Dispatch] → [Permute-Experts-Unpermute] → [Combine]
 #
-# Each component will be independent of the others except for
+# Each component will be independent of (but may inform) the others except for
 # [Quantize-Dispatch] and `[Combine] (see below). The components can then be
 # mixed and matched with so that DP+EP can be supported easily for multiple
 # MoE kernel implementations.
@@ -32,13 +32,19 @@
 # * FusedMoEPrepareAndFinalize - an abstract base class for preparation of MoE
 #   inputs (e.g. quantization, distribution) and finalization of Moe outputs.
 #   The prepare method must take care of any needed quantization and the
-#   finalize method must apply weights and do the final reduction of the output.
+#   finalize method, informed by the FusedMoEPermuteExpertsUnpermute method,
+#   may apply weights and/or do the final reduction of the output.
 # * FusedMoEPermuteExpertsUnpermute - an abstract base class for the main fused
-#   MoE operation. One important feature to note is that this class does not
-#   apply topk weights or reduce the final output.
+#   MoE operation, i.e matmul + act_mul + optionally quant + matmul.
+#   Some FusedMoEPermuteExpertsUnpermute implementations may choose to do
+#   the weight application and/or reduction. The class communicates this
+#   to [Finalize] via a TopKWeightAndReduce object.
 # * FusedMoEModularKernel - an interface class that combines a
 #   FusedMoEPrepareAndFinalize and a FusedMoEPermuteExpertsUnpermute to
 #   provide the standard fused MoE kernel interface.
+# * TopKWeightAndReduce - A TopKWeightAndReduce implementation chosen
+#   by the FusedMoEPermuteExpertsUnpermute implementation that is passed
+#   on to [Finalize].
 #
 # [Quantize-Prepare] and [Finalize] functionality are bundled into a single
 # class `FusedMoEPrepareAndFinalize` since they could use collective
@@ -117,6 +123,24 @@ def make_from_list(expert_num_tokens_list: list[int],
                    expert_num_tokens_cpu=expert_num_tokens_cpu)


+class TopKWeightAndReduce(ABC):
+    """
+    An abstract base class for weight application and reduction implementations.
+    """
+
+    @abstractmethod
+    def apply(self, output: Optional[torch.Tensor],
+              fused_expert_output: torch.Tensor, topk_weights: torch.Tensor,
+              topk_ids: torch.Tensor,
+              apply_router_weight_on_input: bool) -> torch.Tensor:
+        """
+        Apply topk_weights to the fused_experts_outputs and/or reduce.
+        If an output tensor is not passed, it will be created in the
+        function.
+        """
+        raise NotImplementedError
+
+
 # TODO: pass FusedMoEParallelConfig in as ctor parameter?
 class FusedMoEPrepareAndFinalize(ABC):
     """
@@ -173,6 +197,7 @@ def finalize(
         topk_weights: torch.Tensor,
         topk_ids: torch.Tensor,
         apply_router_weight_on_input: bool,
+        weight_and_reduce_impl: TopKWeightAndReduce,
     ) -> None:
         """
         Perform any combine plus apply weights and perform a reduction on the
@@ -184,6 +209,8 @@
         - topk_ids: The topk_ids.
         - apply_router_weight_on_input: When False, apply the weights to
           fused_expert_output.
+        - weight_and_reduce_impl: An optional TopKWeightAndReduce
+          implementation.
         """
         raise NotImplementedError

@@ -323,6 +350,9 @@ def enable_chunking(self):
         return envs.VLLM_ENABLE_FUSED_MOE_ACTIVATION_CHUNKING and \
             self.supports_chunking()

+    def finalize_weight_and_reduce_impl(self) -> TopKWeightAndReduce:
+        raise NotImplementedError
+
     @abstractmethod
     def apply(
         self,
@@ -702,7 +732,9 @@ def forward(
             a2_scale=a2_scale,
             expert_tokens_meta=expert_tokens_meta)

-        self.prepare_finalize.finalize(output, fused_out, topk_weights,
-                                       topk_ids, apply_router_weight_on_input)
+        self.prepare_finalize.finalize(
+            output, fused_out, topk_weights, topk_ids,
+            apply_router_weight_on_input,
+            self.fused_experts.finalize_weight_and_reduce_impl())

         return output
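
Note: putting the pieces together, FusedMoEModularKernel.forward() now asks the experts object which weight-and-reduce strategy it wants and threads that into PrepareAndFinalize::finalize(), which either honours the choice or, on receiving a delegate, substitutes its own default. Below is a self-contained toy illustrating that contract; all classes here are stand-ins, and only the method names and the delegate-swap behaviour mirror the diff above.

from typing import Optional

import torch


class ToyWeightAndReduce:
    """Stand-in for a concrete TopKWeightAndReduce implementation."""

    def apply(self, output: Optional[torch.Tensor],
              fused_expert_output: torch.Tensor, topk_weights: torch.Tensor,
              topk_ids: torch.Tensor,
              apply_router_weight_on_input: bool) -> torch.Tensor:
        t = fused_expert_output
        if not apply_router_weight_on_input:
            t = t * topk_weights.unsqueeze(-1)
        reduced = t.sum(dim=1)
        if output is None:
            return reduced
        output.copy_(reduced)
        return output


class ToyDelegate(ToyWeightAndReduce):
    """Marker: let finalize() decide the implementation."""


def toy_finalize(output, fused_expert_output, topk_weights, topk_ids,
                 apply_router_weight_on_input, weight_and_reduce_impl):
    # Mirrors the pattern in the DeepEP HT / batched finalize() methods:
    # a delegate is swapped for a concrete default implementation.
    if isinstance(weight_and_reduce_impl, ToyDelegate):
        weight_and_reduce_impl = ToyWeightAndReduce()
    weight_and_reduce_impl.apply(
        output=output,
        fused_expert_output=fused_expert_output,
        topk_weights=topk_weights,
        topk_ids=topk_ids,
        apply_router_weight_on_input=apply_router_weight_on_input)


tokens, topk, hidden = 4, 2, 8
fused = torch.randn(tokens, topk, hidden)
weights = torch.rand(tokens, topk)
ids = torch.randint(0, 16, (tokens, topk))
out = torch.empty(tokens, hidden)
toy_finalize(out, fused, weights, ids, False, ToyDelegate())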
