
Commit 805d62c

varun-sundar-rabindranath and Varun authored

[Misc] DP : Add ExpertTokensMetadata (#20332)

Signed-off-by: Varun <vsundarr@redhat.com>
Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Co-authored-by: Varun <vsundarr@redhat.com>
1 parent b7d9e94 commit 805d62c

12 files changed: +117 −79 lines changed
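In summary, this commit replaces the bare expert_num_tokens tensor that prepare() returned and apply() consumed with a small ExpertTokensMetadata dataclass bundling the device-side per-expert token counts with an optional CPU copy. The snippet below is a minimal, self-contained mock of that new flow based on the hunks that follow; the standalone apply() function and the count values are hypothetical and only illustrate how the metadata object travels from prepare to apply.

import torch
from dataclasses import dataclass
from typing import Optional

# Illustrative mirror of the dataclass added in modular_kernel.py.
@dataclass
class ExpertTokensMetadata:
    expert_num_tokens: torch.Tensor                # per-expert counts on the compute device
    expert_num_tokens_cpu: Optional[torch.Tensor]  # optional CPU copy, may be None

# Hypothetical stand-in for an experts.apply() implementation: it unpacks the
# device-side counts the same way the batched kernels in this commit do.
def apply(expert_tokens_meta: Optional[ExpertTokensMetadata]) -> torch.Tensor:
    assert expert_tokens_meta is not None
    expert_num_tokens = expert_tokens_meta.expert_num_tokens
    return expert_num_tokens.sum()

# A prepare() step now wraps its counts instead of returning a raw tensor.
counts = torch.tensor([3, 0, 5, 2], dtype=torch.int32)
meta = ExpertTokensMetadata(expert_num_tokens=counts, expert_num_tokens_cpu=None)
print(apply(meta))  # tensor(10)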

vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py

Lines changed: 4 additions & 2 deletions
@@ -260,8 +260,11 @@ def apply(
         a2_scale: Optional[torch.Tensor],
         workspace13: torch.Tensor,
         workspace2: torch.Tensor,
-        expert_num_tokens: Optional[torch.Tensor],
+        expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
     ):
+        assert expert_tokens_meta is not None
+        expert_num_tokens = expert_tokens_meta.expert_num_tokens
+
         import deep_gemm as dg
         assert hidden_states.ndim == 3
         assert self.block_shape is not None
@@ -287,7 +290,6 @@ def apply(
             masked_m=expert_num_tokens,
             expected_m=expected_m)

-        assert expert_num_tokens is not None
         a2q, a2q_scale = silu_mul_fp8_quant_deep_gemm(workspace1,
                                                       expert_num_tokens)

vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py

Lines changed: 2 additions & 2 deletions
@@ -129,12 +129,12 @@ def apply(
         a2_scale: Optional[torch.Tensor],
         workspace13: torch.Tensor,
         workspace2: torch.Tensor,
-        expert_num_tokens: Optional[torch.Tensor],
+        expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
     ):
         experts = (self.batched_deep_gemm_experts
                    if self.allow_deep_gemm else self.batched_triton_experts)
         assert experts is not None
         experts.apply(output, hidden_states, w1, w2, topk_ids, activation,
                       global_num_experts, expert_map, w1_scale, w2_scale,
                       w1_zp, w2_zp, a1q_scale, a2_scale, workspace13,
-                      workspace2, expert_num_tokens)
+                      workspace2, expert_tokens_meta)

vllm/model_executor/layers/fused_moe/cutlass_moe.py

Lines changed: 7 additions & 1 deletion
@@ -303,11 +303,17 @@ def apply(
         a2_scale: Optional[torch.Tensor],
         workspace13: torch.Tensor,
         workspace2: torch.Tensor,
-        expert_num_tokens: Optional[torch.Tensor],
+        expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
     ):
         assert w1_zp is None, "w1_zp is not supported in CUTLASS MoE"
         assert w2_zp is None, "w2_zp is not supported in CUTLASS MoE"
+
+        expert_num_tokens = None
+        if expert_tokens_meta is not None:
+            expert_num_tokens = expert_tokens_meta.expert_num_tokens
+
         activation_callable = lambda o, i: self.activation(activation, o, i)
+
         in_dtype = hidden_states.dtype
         run_cutlass_moe_fp8(
             output, hidden_states, w1, w2, topk_ids, activation_callable,

vllm/model_executor/layers/fused_moe/deep_gemm_moe.py

Lines changed: 1 addition & 1 deletion
@@ -119,7 +119,7 @@ def apply(
         a2_scale: Optional[torch.Tensor],
         workspace13: torch.Tensor,
         workspace2: torch.Tensor,
-        expert_num_tokens: Optional[torch.Tensor],
+        expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
     ):
         import deep_gemm as dg
         assert self.block_shape is not None

vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py

Lines changed: 17 additions & 9 deletions
@@ -62,8 +62,9 @@ def _do_dispatch(self, tokens: torch.Tensor,

         has_scales = token_scales is not None

-        (num_tokens_per_rank, num_tokens_per_rdma_rank, expert_num_tokens,
-         is_token_in_rank, event) = self.buffer.get_dispatch_layout(
+        (num_tokens_per_rank, num_tokens_per_rdma_rank,
+         dispatch_expert_num_tokens, is_token_in_rank,
+         event) = self.buffer.get_dispatch_layout(
             topk_idx=rank_topk_ids,
             num_experts=num_experts,
             previous_event=None,
@@ -83,7 +84,7 @@ def _do_dispatch(self, tokens: torch.Tensor,
             num_tokens_per_rank=num_tokens_per_rank,
             num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
             is_token_in_rank=is_token_in_rank,
-            num_tokens_per_expert=expert_num_tokens,
+            num_tokens_per_expert=dispatch_expert_num_tokens,
             topk_idx=rank_topk_ids,
             topk_weights=rank_topk_weights,
             # expert_alignment rounds the number of tokens per expert
@@ -115,7 +116,13 @@ def _do_dispatch(self, tokens: torch.Tensor,
             num_experts - 1 if self.rank_expert_offset == 0 else 0,
             expert_topk_ids + self.rank_expert_offset)

-        return (expert_x, expert_x_scale, expert_num_tokens, expert_topk_ids,
+        # Makes a GPU-CPU copy.
+        # TODO (varun): Maybe it is better to re-compute the expert_num_tokens
+        # on GPU.
+        expert_tokens_meta = mk.ExpertTokensMetadata.make_from_list(
+            expert_num_tokens_per_expert_list, device=expert_x.device)
+
+        return (expert_x, expert_x_scale, expert_tokens_meta, expert_topk_ids,
                 expert_topk_weights)

     def prepare(
@@ -129,8 +136,9 @@ def prepare(
         expert_map: Optional[torch.Tensor],
         apply_router_weight_on_input: bool,
         quant_config: FusedMoEQuantConfig,
-    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor],
-               Optional[torch.Tensor], Optional[torch.Tensor]]:
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor],
+               Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor],
+               Optional[torch.Tensor]]:

         if apply_router_weight_on_input:
             topk = topk_ids.size(1)
@@ -149,7 +157,7 @@ def prepare(
             )
             if a1q_scale is not None and a1q_scale.numel() == 1:
                 a1q_scale = a1q_scale.view(1, 1)
-            (expert_x, expert_x_scale, expert_num_tokens, expert_topk_ids,
+            (expert_x, expert_x_scale, expert_tokens_meta, expert_topk_ids,
              expert_topk_weights) = self._do_dispatch(
                  tokens=a1q,
                  token_scales=a1q_scale,
@@ -159,7 +167,7 @@ def prepare(
         else:
             # DeepEP kernels only support dispatching per-token-quant
             # quantization. dispatch in bfloat16.
-            (expert_x, _, expert_num_tokens, expert_topk_ids,
+            (expert_x, _, expert_tokens_meta, expert_topk_ids,
              expert_topk_weights) = self._do_dispatch(
                  tokens=a1,
                  token_scales=None,
@@ -176,7 +184,7 @@ def prepare(
                 per_act_token_quant=False,
                 block_shape=quant_config.block_shape)

-        return (expert_x, expert_x_scale, expert_num_tokens, expert_topk_ids,
+        return (expert_x, expert_x_scale, expert_tokens_meta, expert_topk_ids,
                 expert_topk_weights)

     def _apply_weights_and_reduce(self, num_tokens: int,
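The TODO in the hunk above notes that make_from_list builds the metadata from a Python list, which forces a CPU-to-GPU copy; the counts could instead be recomputed on the device from the dispatched topk ids. The snippet below is only an illustrative sketch of that alternative (not part of this commit), assuming topk_ids holds local expert indices with -1 for tokens not routed to this rank.

import torch

def expert_num_tokens_on_gpu(topk_ids: torch.Tensor,
                             num_local_experts: int) -> torch.Tensor:
    # Count tokens per local expert entirely on the device; -1 entries
    # (tokens not assigned to this rank) are dropped before counting.
    valid = topk_ids[topk_ids >= 0]
    return torch.bincount(valid, minlength=num_local_experts).to(torch.int32)

# Example with 4 hypothetical local experts.
topk_ids = torch.tensor([[0, 2], [2, -1], [3, 0]])
print(expert_num_tokens_on_gpu(topk_ids, 4))  # tensor([2, 0, 2, 1], dtype=torch.int32)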

vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py

Lines changed: 7 additions & 3 deletions
@@ -119,8 +119,9 @@ def prepare(
         expert_map: Optional[torch.Tensor],
         apply_router_weight_on_input: bool,
         quant_config: FusedMoEQuantConfig,
-    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor],
-               Optional[torch.Tensor], Optional[torch.Tensor]]:
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor],
+               Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor],
+               Optional[torch.Tensor]]:

         hidden_size = a1.size(1)
         assert hidden_size in self.SUPPORTED_HIDDEN_SIZES, \
@@ -158,7 +159,10 @@ def prepare(
             expert_x, a1_scale, a2_scale, a1.dtype, quant_config.quant_dtype,
             quant_config.per_act_token_quant, quant_config.block_shape)

-        return (expert_x, expert_x_scale, expert_num_tokens, None, None)
+        expert_tokens_meta = mk.ExpertTokensMetadata(
+            expert_num_tokens=expert_num_tokens, expert_num_tokens_cpu=None)
+
+        return (expert_x, expert_x_scale, expert_tokens_meta, None, None)

     def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor,
                  topk_weights: torch.Tensor, topk_ids: torch.Tensor,

vllm/model_executor/layers/fused_moe/fused_batched_moe.py

Lines changed: 32 additions & 44 deletions
@@ -505,8 +505,9 @@ def prepare(
         expert_map: Optional[torch.Tensor],
         apply_router_weight_on_input: bool,
         quant_config: FusedMoEQuantConfig,
-    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor],
-               Optional[torch.Tensor], Optional[torch.Tensor]]:
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor],
+               Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor],
+               Optional[torch.Tensor]]:
         assert a1.dim() == 2
         assert topk_ids.dim() == 2
         assert topk_ids.size(0) == a1.size(0)
@@ -587,7 +588,10 @@ def prepare(

         assert b_a1_scale is None or b_a1_scale.ndim == 3

-        return b_a1, b_a1_scale, tokens_per_expert, None, None
+        expert_tokens_meta = mk.ExpertTokensMetadata(
+            expert_num_tokens=tokens_per_expert, expert_num_tokens_cpu=None)
+
+        return b_a1, b_a1_scale, expert_tokens_meta, None, None

     def finalize(
         self,
@@ -694,28 +698,19 @@ def dequant(self, t: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
         else:
             return t.to(f32) * group_broadcast(scale, t.shape)

-    def apply(
-        self,
-        output: torch.Tensor,
-        hidden_states: torch.Tensor,
-        w1: torch.Tensor,
-        w2: torch.Tensor,
-        topk_ids: torch.Tensor,
-        activation: str,
-        global_num_experts: int,
-        expert_map: Optional[torch.Tensor],
-        w1_scale: Optional[torch.Tensor],
-        w2_scale: Optional[torch.Tensor],
-        w1_zp: Optional[torch.Tensor],
-        w2_zp: Optional[torch.Tensor],
-        a1q_scale: Optional[torch.Tensor],
-        a2_scale: Optional[torch.Tensor],
-        workspace13: torch.Tensor,
-        workspace2: torch.Tensor,
-        expert_num_tokens: Optional[torch.Tensor],
-    ):
+    def apply(self, output: torch.Tensor, hidden_states: torch.Tensor,
+              w1: torch.Tensor, w2: torch.Tensor, topk_ids: torch.Tensor,
+              activation: str, global_num_experts: int,
+              expert_map: Optional[torch.Tensor],
+              w1_scale: Optional[torch.Tensor],
+              w2_scale: Optional[torch.Tensor], w1_zp: Optional[torch.Tensor],
+              w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor],
+              a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor,
+              workspace2: torch.Tensor,
+              expert_tokens_meta: Optional[mk.ExpertTokensMetadata]):
         assert hidden_states.dim() == 3
-        assert expert_num_tokens is not None
+        assert expert_tokens_meta is not None
+        expert_num_tokens = expert_tokens_meta.expert_num_tokens

         num_local_experts = w1.size(0)
         assert num_local_experts == w1.size(0), (
@@ -902,26 +897,16 @@ def workspace_shapes(
         output = (num_experts, max_num_tokens * num_dp, K)
         return (workspace13, workspace2, output, a.dtype)

-    def apply(
-        self,
-        output: torch.Tensor,
-        hidden_states: torch.Tensor,
-        w1: torch.Tensor,
-        w2: torch.Tensor,
-        topk_ids: torch.Tensor,
-        activation: str,
-        global_num_experts: int,
-        expert_map: Optional[torch.Tensor],
-        w1_scale: Optional[torch.Tensor],
-        w2_scale: Optional[torch.Tensor],
-        w1_zp: Optional[torch.Tensor],
-        w2_zp: Optional[torch.Tensor],
-        a1q_scale: Optional[torch.Tensor],
-        a2_scale: Optional[torch.Tensor],
-        workspace13: torch.Tensor,
-        workspace2: torch.Tensor,
-        expert_num_tokens: Optional[torch.Tensor],
-    ):
+    def apply(self, output: torch.Tensor, hidden_states: torch.Tensor,
+              w1: torch.Tensor, w2: torch.Tensor, topk_ids: torch.Tensor,
+              activation: str, global_num_experts: int,
+              expert_map: Optional[torch.Tensor],
+              w1_scale: Optional[torch.Tensor],
+              w2_scale: Optional[torch.Tensor], w1_zp: Optional[torch.Tensor],
+              w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor],
+              a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor,
+              workspace2: torch.Tensor,
+              expert_tokens_meta: Optional[mk.ExpertTokensMetadata]):
         # Check constraints.
         if self.use_int4_w4a16:
             assert hidden_states.size(-1) // 2 == w1.size(2), (
@@ -938,6 +923,9 @@ def apply(
         assert hidden_states.dtype in [
             torch.float32, torch.float16, torch.bfloat16, torch.float8_e4m3fn
         ]
+        assert expert_tokens_meta is not None
+
+        expert_num_tokens = expert_tokens_meta.expert_num_tokens

         E, max_num_tokens, N, K, top_k_num = mk._moe_problem_size(
             hidden_states, w1, w2, topk_ids)

vllm/model_executor/layers/fused_moe/fused_moe.py

Lines changed: 1 addition & 1 deletion
@@ -1630,7 +1630,7 @@ def apply(
         a2_scale: Optional[torch.Tensor],
         workspace13: torch.Tensor,
         workspace2: torch.Tensor,
-        expert_num_tokens: Optional[torch.Tensor],
+        expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
     ):
         # Check constraints.
         if self.use_int4_w4a16:

vllm/model_executor/layers/fused_moe/modular_kernel.py

Lines changed: 34 additions & 9 deletions
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from abc import ABC, abstractmethod
+from dataclasses import dataclass
 from enum import Enum
 from math import prod
 from typing import Optional, final
@@ -95,6 +96,26 @@ class FusedMoEActivationFormat(Enum):
     BatchedExperts = "batched_experts",


+@dataclass
+class ExpertTokensMetadata:
+    """
+    Metadata regarding expert-token routing.
+    """
+    expert_num_tokens: torch.Tensor
+    expert_num_tokens_cpu: Optional[torch.Tensor]
+
+    @staticmethod
+    def make_from_list(expert_num_tokens_list: list[int],
+                       device: str) -> "ExpertTokensMetadata":
+        expert_num_tokens_cpu = torch.tensor(expert_num_tokens_list,
+                                             device="cpu",
+                                             dtype=torch.int32)
+        return ExpertTokensMetadata(
+            expert_num_tokens=expert_num_tokens_cpu.to(device,
+                                                       non_blocking=True),
+            expert_num_tokens_cpu=expert_num_tokens_cpu)
+
+
 # TODO: pass FusedMoEParallelConfig in as ctor parameter?
 class FusedMoEPrepareAndFinalize(ABC):
     """
@@ -114,8 +135,9 @@ def prepare(
         expert_map: Optional[torch.Tensor],
         apply_router_weight_on_input: bool,
         quant_config: FusedMoEQuantConfig,
-    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor],
-               Optional[torch.Tensor], Optional[torch.Tensor]]:
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor],
+               Optional[ExpertTokensMetadata], Optional[torch.Tensor],
+               Optional[torch.Tensor]]:
         """
         Perform any quantization (and/or) dispatching needed
         for this kernel.
@@ -134,7 +156,8 @@ def prepare(
         Returns a tuple of:
         - quantized + dispatched a.
         - quantized + dispatched a1_scales.
-        - Optional tensor as big as number of local experts that contains the
+        - Optional ExpertTokensMetadata containing gpu/cpu tensors
+          as big as the number of local experts with the information about the
           number of tokens assigned to each local expert.
         - Optional dispatched expert topk IDs
         - Optional dispatched expert topk weight
@@ -318,7 +341,7 @@ def apply(
         a2_scale: Optional[torch.Tensor],
         workspace13: torch.Tensor,
         workspace2: torch.Tensor,
-        expert_num_tokens: Optional[torch.Tensor],
+        expert_tokens_meta: Optional[ExpertTokensMetadata],
     ):
         """
         This function computes the intermediate result of a Mixture of Experts
@@ -351,8 +374,10 @@ def apply(
           must be large enough to hold output of either MoE gemm.
         - workspace2 (torch.Tensor): A scratch tensor used for the activation
          function.
-        - expert_num_tokens: An optional tensor containing the number of tokens
-          assigned to each expert when using batched experts format input.
+        - expert_tokens_meta (Optional[ExpertTokensMetadata]) - An optional
+          ExpertTokensMetadata object containing gpu/cpu tensors
+          as big as the number of local experts with the information about the
+          number of tokens assigned to each local expert.
         """
         raise NotImplementedError

@@ -458,7 +483,7 @@ def forward(
         if global_num_experts == -1:
             global_num_experts = local_num_experts

-        (a1q, a1q_scale, expert_num_tokens, _expert_topk_ids,
+        (a1q, a1q_scale, expert_tokens_meta, _expert_topk_ids,
          _expert_topk_weights) = self.prepare_finalize.prepare(
              a1,
              a1_scale,
@@ -542,7 +567,7 @@ def forward(
                 a2_scale=a2_scale,
                 workspace13=workspace13,
                 workspace2=workspace2,
-                expert_num_tokens=expert_num_tokens,
+                expert_tokens_meta=expert_tokens_meta,
             )
         else:
             # The leading output dimension may not be equal to M, so
@@ -589,7 +614,7 @@ def forward(
                 a2_scale=curr_a2_scale,
                 workspace13=workspace13,
                 workspace2=workspace2,
-                expert_num_tokens=expert_num_tokens,
+                expert_tokens_meta=expert_tokens_meta,
             )

         self.prepare_finalize.finalize(output, fused_out, topk_weights,
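The dataclass added above keeps the device tensor and an optional CPU mirror of the per-expert counts on one object, so callers that need a host-side view do not have to trigger their own synchronization. As a quick usage check, the sketch below builds the metadata via make_from_list and reads back both views; it assumes a vLLM installation that includes this commit, and the four-element count list is hypothetical.

import torch
from vllm.model_executor.layers.fused_moe import modular_kernel as mk

device = "cuda" if torch.cuda.is_available() else "cpu"

# Per-expert token counts for 4 hypothetical local experts.
meta = mk.ExpertTokensMetadata.make_from_list([3, 0, 5, 2], device=device)

# The device tensor feeds the kernels...
print(meta.expert_num_tokens.device, meta.expert_num_tokens.dtype)
# ...while the CPU copy is available without a device sync.
print(meta.expert_num_tokens_cpu.tolist())  # [3, 0, 5, 2]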
