Commit c0569db

Authored by Varun Sundar Rabindranath (varun-sundar-rabindranath)
[Misc] ModularKernel : Perform WeightAndReduce inside TritonExperts & DeepGemmExperts (#20725)
Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
1 parent 8bb43b9 commit c0569db
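
This change moves the top-k "weight and reduce" step (scaling each expert's output by its routing weight and summing over the top-k experts) out of PrepareAndFinalize::finalize() and into TritonExperts and DeepGemmExperts. To support this, the experts implementations touched in this diff gain two apply() arguments, topk_weights and apply_router_weight_on_input; TritonExperts and DeepGemmExperts now write a reduced (M, K) output instead of a per-expert (M, topk, K) output and report TopKWeightAndReduceNoOP from finalize_weight_and_reduce_impl(), so finalize() does not repeat the work. The snippet below is an illustrative, self-contained sketch of the operation being moved, not the vLLM implementation.

import torch

# Sketch only: the "weight and reduce" step in isolation.
# fused_expert_output: (M, topk, K), topk_weights: (M, topk) -> returns (M, K).
def weight_and_reduce(fused_expert_output: torch.Tensor,
                      topk_weights: torch.Tensor,
                      apply_router_weight_on_input: bool) -> torch.Tensor:
    if not apply_router_weight_on_input:
        # Scale each expert's output by its routing weight
        # (skipped when the weights were already folded into the input).
        weights = topk_weights.unsqueeze(-1).to(fused_expert_output.dtype)
        fused_expert_output = fused_expert_output * weights
    # Sum over the top-k experts selected for each token.
    return fused_expert_output.sum(dim=1)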

File tree

9 files changed (+203 lines added, -157 lines removed)


vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py

Lines changed: 2 additions & 0 deletions
@@ -260,6 +260,7 @@ def apply(
         hidden_states: torch.Tensor,
         w1: torch.Tensor,
         w2: torch.Tensor,
+        topk_weights: torch.Tensor,
         topk_ids: torch.Tensor,
         activation: str,
         global_num_experts: int,
@@ -273,6 +274,7 @@ def apply(
         workspace13: torch.Tensor,
         workspace2: torch.Tensor,
         expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
+        apply_router_weight_on_input: bool,
     ):
         assert expert_tokens_meta is not None
         expert_num_tokens = expert_tokens_meta.expert_num_tokens

vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py

Lines changed: 16 additions & 24 deletions
@@ -129,30 +129,22 @@ def workspace_shapes(
         return self.batched_triton_experts.workspace_shapes(
             a, aq, M, N, K, topk, global_num_experts, local_num_experts)
 
-    def apply(
-        self,
-        output: torch.Tensor,
-        hidden_states: torch.Tensor,
-        w1: torch.Tensor,
-        w2: torch.Tensor,
-        topk_ids: torch.Tensor,
-        activation: str,
-        global_num_experts: int,
-        expert_map: Optional[torch.Tensor],
-        w1_scale: Optional[torch.Tensor],
-        w2_scale: Optional[torch.Tensor],
-        w1_zp: Optional[torch.Tensor],
-        w2_zp: Optional[torch.Tensor],
-        a1q_scale: Optional[torch.Tensor],
-        a2_scale: Optional[torch.Tensor],
-        workspace13: torch.Tensor,
-        workspace2: torch.Tensor,
-        expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
-    ):
+    def apply(self, output: torch.Tensor, hidden_states: torch.Tensor,
+              w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor,
+              topk_ids: torch.Tensor, activation: str, global_num_experts: int,
+              expert_map: Optional[torch.Tensor],
+              w1_scale: Optional[torch.Tensor],
+              w2_scale: Optional[torch.Tensor], w1_zp: Optional[torch.Tensor],
+              w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor],
+              a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor,
+              workspace2: torch.Tensor,
+              expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
+              apply_router_weight_on_input: bool):
         experts = (self.batched_deep_gemm_experts
                    if self.allow_deep_gemm else self.batched_triton_experts)
         assert experts is not None
-        experts.apply(output, hidden_states, w1, w2, topk_ids, activation,
-                      global_num_experts, expert_map, w1_scale, w2_scale,
-                      w1_zp, w2_zp, a1q_scale, a2_scale, workspace13,
-                      workspace2, expert_tokens_meta)
+        experts.apply(output, hidden_states, w1, w2, topk_weights, topk_ids,
+                      activation, global_num_experts, expert_map, w1_scale,
+                      w2_scale, w1_zp, w2_zp, a1q_scale, a2_scale, workspace13,
+                      workspace2, expert_tokens_meta,
+                      apply_router_weight_on_input)

vllm/model_executor/layers/fused_moe/cutlass_moe.py

Lines changed: 11 additions & 20 deletions
@@ -291,26 +291,17 @@ def workspace_shapes(
         return (workspace1, workspace2, output,
                 self.out_dtype if self.out_dtype is not None else a.dtype)
 
-    def apply(
-        self,
-        output: torch.Tensor,
-        hidden_states: torch.Tensor,
-        w1: torch.Tensor,
-        w2: torch.Tensor,
-        topk_ids: torch.Tensor,
-        activation: str,
-        global_num_experts: int,
-        expert_map: Optional[torch.Tensor],
-        w1_scale: Optional[torch.Tensor],
-        w2_scale: Optional[torch.Tensor],
-        w1_zp: Optional[torch.Tensor],
-        w2_zp: Optional[torch.Tensor],
-        a1q_scale: Optional[torch.Tensor],
-        a2_scale: Optional[torch.Tensor],
-        workspace13: torch.Tensor,
-        workspace2: torch.Tensor,
-        expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
-    ):
+    def apply(self, output: torch.Tensor, hidden_states: torch.Tensor,
+              w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor,
+              topk_ids: torch.Tensor, activation: str, global_num_experts: int,
+              expert_map: Optional[torch.Tensor],
+              w1_scale: Optional[torch.Tensor],
+              w2_scale: Optional[torch.Tensor], w1_zp: Optional[torch.Tensor],
+              w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor],
+              a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor,
+              workspace2: torch.Tensor,
+              expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
+              apply_router_weight_on_input: bool):
         assert w1_zp is None, "w1_zp is not supported in CUTLASS MoE"
         assert w2_zp is None, "w2_zp is not supported in CUTLASS MoE"
 
vllm/model_executor/layers/fused_moe/deep_gemm_moe.py

Lines changed: 21 additions & 10 deletions
@@ -13,7 +13,7 @@
 from vllm.model_executor.layers.fused_moe.prepare_finalize import (
     MoEPrepareAndFinalizeNoEP)
 from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
-    TopKWeightAndReduceDelegate)
+    TopKWeightAndReduceContiguous, TopKWeightAndReduceNoOP)
 from vllm.model_executor.layers.fused_moe.utils import _resize_cache
 from vllm.model_executor.layers.quantization.utils.fp8_utils import (
     per_token_group_quant_fp8)
@@ -90,8 +90,7 @@ def supports_expert_map(self) -> bool:
         return True
 
     def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
-        # Let PrepareAndFinalize::finalize() decide the impl.
-        return TopKWeightAndReduceDelegate()
+        return TopKWeightAndReduceNoOP()
 
     def workspace_shapes(
         self, a: torch.Tensor, aq: torch.Tensor, M: int, N: int, K: int,
@@ -104,9 +103,9 @@ def workspace_shapes(
         block_m = self.block_shape[0]
         M_sum = (M * topk) + num_experts * (block_m - 1)
         M_sum = round_up(M_sum, block_m)
-        workspace1 = (M_sum, max(N * 2, K))
+        workspace1 = (M_sum, max(N // 2, K))
         workspace2 = (M_sum, max(N, K))
-        output = (M, topk, K)
+        output = (M, K)
         return (workspace1, workspace2, output, a.dtype)
 
     def apply(
@@ -115,6 +114,7 @@ def apply(
         hidden_states: torch.Tensor,
         w1: torch.Tensor,
         w2: torch.Tensor,
+        topk_weights: torch.Tensor,
         topk_ids: torch.Tensor,
         activation: str,
         global_num_experts: int,
@@ -128,11 +128,14 @@ def apply(
         workspace13: torch.Tensor,
         workspace2: torch.Tensor,
         expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
+        apply_router_weight_on_input: bool,
     ):
         assert self.block_shape is not None
 
         a1q = hidden_states
         _, N, K = w1.size()
+        M, _ = output.size()
+        num_topk = topk_ids.size(1)
 
         if global_num_experts == -1:
             global_num_experts = w1.size(0)
@@ -159,11 +162,12 @@
         # Note: M_sum is different than the pre-permuted shape of a1q.
         M_sum = a1q.size(0)
 
-        mm1_out = _resize_cache(workspace13, (M_sum, N))
-        act_out = _resize_cache(workspace2, (M_sum, N // 2))
-        quant_out = _resize_cache(workspace13.view(dtype=torch.float8_e4m3fn),
+        mm1_out = _resize_cache(workspace2, (M_sum, N))
+        act_out = _resize_cache(workspace13, (M_sum, N // 2))
+        quant_out = _resize_cache(workspace2.view(dtype=torch.float8_e4m3fn),
                                   (M_sum, N // 2))
-        mm2_out = _resize_cache(workspace2, (M_sum, K))
+        mm2_out = _resize_cache(workspace13, (M_sum, K))
+        perm_out = _resize_cache(workspace2, (M * num_topk, K))
 
         m_grouped_fp8_gemm_nt_contiguous((a1q, a1q_scale), (w1, w1_scale),
                                          mm1_out, expert_ids)
@@ -179,7 +183,14 @@
         m_grouped_fp8_gemm_nt_contiguous((a2q, a2q_scale), (w2, w2_scale),
                                          mm2_out, expert_ids)
 
-        torch.index_select(mm2_out, 0, inv_perm, out=output.view((-1, K)))
+        torch.index_select(mm2_out, 0, inv_perm, out=perm_out)
+
+        TopKWeightAndReduceContiguous().apply(
+            output=output,
+            fused_expert_output=perm_out,
+            topk_weights=topk_weights,
+            topk_ids=topk_ids,
+            apply_router_weight_on_input=apply_router_weight_on_input)
 
 
 def deep_gemm_moe_fp8(
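
With the output now shaped (M, K), DeepGemmExperts un-permutes the second grouped GEMM result into perm_out of shape (M * num_topk, K) and finishes with TopKWeightAndReduceContiguous().apply(...), as the hunk above shows. The workspace roles are also swapped (mm1_out and quant_out now live in workspace2, act_out and mm2_out in workspace13), so each stage still reads from one buffer while writing to the other, and workspace2 is free to receive perm_out once mm2_out has been consumed. Below is a toy walk-through of the new tail of apply(), with random data standing in for the real GEMM output and inverse permutation (illustrative only, not the vLLM code):

import torch

M, num_topk, K = 4, 2, 8
topk_weights = torch.rand(M, num_topk)
inv_perm = torch.randperm(M * num_topk)      # stand-in for the real inverse permutation
mm2_out = torch.randn(M * num_topk, K)       # stand-in for the second grouped-GEMM output
perm_out = torch.empty(M * num_topk, K)
output = torch.empty(M, K)

# Un-permute rows back into token-major order, then weight and reduce
# (the result TopKWeightAndReduceContiguous is expected to produce).
torch.index_select(mm2_out, 0, inv_perm, out=perm_out)
weighted = perm_out.view(M, num_topk, K) * topk_weights.unsqueeze(-1)
output.copy_(weighted.sum(dim=1))
print(output.shape)  # torch.Size([4, 8])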

vllm/model_executor/layers/fused_moe/fused_batched_moe.py

Lines changed: 8 additions & 6 deletions
@@ -696,15 +696,16 @@ def dequant(self, t: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
         return t.to(f32) * group_broadcast(scale, t.shape)
 
     def apply(self, output: torch.Tensor, hidden_states: torch.Tensor,
-              w1: torch.Tensor, w2: torch.Tensor, topk_ids: torch.Tensor,
-              activation: str, global_num_experts: int,
+              w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor,
+              topk_ids: torch.Tensor, activation: str, global_num_experts: int,
               expert_map: Optional[torch.Tensor],
               w1_scale: Optional[torch.Tensor],
               w2_scale: Optional[torch.Tensor], w1_zp: Optional[torch.Tensor],
               w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor],
               a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor,
               workspace2: torch.Tensor,
-              expert_tokens_meta: Optional[mk.ExpertTokensMetadata]):
+              expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
+              apply_router_weight_on_input: bool):
         assert hidden_states.dim() == 3
         assert expert_tokens_meta is not None
         expert_num_tokens = expert_tokens_meta.expert_num_tokens
@@ -899,15 +900,16 @@ def workspace_shapes(
         return (workspace13, workspace2, output, a.dtype)
 
     def apply(self, output: torch.Tensor, hidden_states: torch.Tensor,
-              w1: torch.Tensor, w2: torch.Tensor, topk_ids: torch.Tensor,
-              activation: str, global_num_experts: int,
+              w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor,
+              topk_ids: torch.Tensor, activation: str, global_num_experts: int,
               expert_map: Optional[torch.Tensor],
               w1_scale: Optional[torch.Tensor],
               w2_scale: Optional[torch.Tensor], w1_zp: Optional[torch.Tensor],
               w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor],
               a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor,
               workspace2: torch.Tensor,
-              expert_tokens_meta: Optional[mk.ExpertTokensMetadata]):
+              expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
+              apply_router_weight_on_input: bool):
         # Check constraints.
         if self.use_int4_w4a16:
             assert hidden_states.size(-1) // 2 == w1.size(2), (

vllm/model_executor/layers/fused_moe/fused_moe.py

Lines changed: 38 additions & 33 deletions
@@ -26,7 +26,7 @@
 from vllm.model_executor.layers.fused_moe.prepare_finalize import (
     MoEPrepareAndFinalizeNoEP)
 from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
-    TopKWeightAndReduceDelegate)
+    TopKWeightAndReduceNoOP)
 from vllm.model_executor.layers.fused_moe.utils import (
     _resize_cache, moe_kernel_quantize_input)
 from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
@@ -1606,8 +1606,7 @@ def supports_expert_map(self) -> bool:
         return True
 
     def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
-        # Let PrepareAndFinalize::finalize() decide the impl.
-        return TopKWeightAndReduceDelegate()
+        return TopKWeightAndReduceNoOP()
 
     def workspace_shapes(
         self,
@@ -1620,9 +1619,9 @@ def workspace_shapes(
         global_num_experts: int,
         local_num_experts: int,
     ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
-        workspace1 = (M, topk, max(N * 2, K))
-        workspace2 = (M, topk, N)
-        output = (M, topk, K)
+        workspace1 = (M, topk, max(N // 2, K))
+        workspace2 = (M, topk, max(N, K))
+        output = (M, K)
         return (workspace1, workspace2, output, a.dtype)
 
     def apply(
@@ -1631,6 +1630,7 @@ def apply(
         hidden_states: torch.Tensor,
         w1: torch.Tensor,
         w2: torch.Tensor,
+        topk_weights: torch.Tensor,
         topk_ids: torch.Tensor,
         activation: str,
         global_num_experts: int,
@@ -1644,6 +1644,7 @@ def apply(
         workspace13: torch.Tensor,
         workspace2: torch.Tensor,
         expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
+        apply_router_weight_on_input: bool,
     ):
         # Check constraints.
         if self.use_int4_w4a16:
@@ -1696,37 +1697,39 @@ def apply(
             raise ValueError(
                 f"Unsupported compute_type: {hidden_states.dtype}")
 
-        # We can reuse the memory between these because by the time we need
-        # cache3, we're done with cache1
-        intermediate_cache1 = _resize_cache(workspace13,
+        # Note that the output tensor might be in workspace1
+        intermediate_cache1 = _resize_cache(workspace2,
                                             (num_tokens, top_k_num, N))
-        intermediate_cache2 = _resize_cache(workspace2,
+        intermediate_cache2 = _resize_cache(workspace13,
                                             (num_tokens * top_k_num, N // 2))
+        intermediate_cache3 = _resize_cache(workspace2,
+                                            (num_tokens, top_k_num, K))
 
         sorted_token_ids, expert_ids, num_tokens_post_padded = (
             moe_align_block_size(topk_ids, config['BLOCK_SIZE_M'],
                                  global_num_experts, expert_map))
 
-        invoke_fused_moe_kernel(hidden_states,
-                                w1,
-                                intermediate_cache1,
-                                a1q_scale,
-                                w1_scale,
-                                w1_zp,
-                                None,
-                                sorted_token_ids,
-                                expert_ids,
-                                num_tokens_post_padded,
-                                False,
-                                top_k_num,
-                                config,
-                                compute_type=compute_type,
-                                use_fp8_w8a8=self.use_fp8_w8a8,
-                                use_int8_w8a8=self.use_int8_w8a8,
-                                use_int8_w8a16=self.use_int8_w8a16,
-                                use_int4_w4a16=self.use_int4_w4a16,
-                                per_channel_quant=self.per_act_token_quant,
-                                block_shape=self.block_shape)
+        invoke_fused_moe_kernel(
+            hidden_states,
+            w1,
+            intermediate_cache1,
+            a1q_scale,
+            w1_scale,
+            w1_zp,
+            None,  # topk_weights
+            sorted_token_ids,
+            expert_ids,
+            num_tokens_post_padded,
+            False,  # mul_routed_weights
+            top_k_num,
+            config,
+            compute_type=compute_type,
+            use_fp8_w8a8=self.use_fp8_w8a8,
+            use_int8_w8a8=self.use_int8_w8a8,
+            use_int8_w8a16=self.use_int8_w8a16,
+            use_int4_w4a16=self.use_int4_w4a16,
+            per_channel_quant=self.per_act_token_quant,
+            block_shape=self.block_shape)
 
         self.activation(activation, intermediate_cache2,
                         intermediate_cache1.view(-1, N))
@@ -1739,15 +1742,15 @@ def apply(
 
         invoke_fused_moe_kernel(qintermediate_cache2,
                                 w2,
-                                output,
+                                intermediate_cache3,
                                 a2q_scale,
                                 w2_scale,
                                 w2_zp,
-                                None,
+                                topk_weights,
                                 sorted_token_ids,
                                 expert_ids,
                                 num_tokens_post_padded,
-                                False,
+                                not apply_router_weight_on_input,
                                 1,
                                 config,
                                 compute_type=compute_type,
@@ -1758,6 +1761,8 @@
                                 per_channel_quant=self.per_act_token_quant,
                                 block_shape=self.block_shape)
 
+        ops.moe_sum(intermediate_cache3, output)
+
 
 
 def modular_triton_fused_moe(
