Commit df84cce

perf: use multicast to avoid padding decode request to prefill size (#1555)

### What this PR does / why we need it?
perf: use multicast to avoid padding decode request to prefill size

### How was this patch tested?
- vLLM version: v0.9.1
- vLLM main: vllm-project/vllm@1fd471e

Signed-off-by: boying <897013703@qq.com>

1 parent f08c4f1 commit df84cce
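
Context for the diff below: the change targets the ep_size == 1 data-parallel path. Before this commit, the AllGather path padded every DP rank's batch up to the largest batch in the step (max_num_tokens_across_dp) before all-gathering, so during prefill steps short decode batches were padded to prefill size. The new NaiveMulticast state instead broadcasts each rank's exact token slice into a buffer sized by cumulative token counts. As an illustrative example (not a measurement from the PR): with per-rank token counts [1, 1, 1, 512], the padded all-gather moves 4 × 512 = 2048 rows, while the multicast moves only 1 + 1 + 1 + 512 = 515.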

File tree

3 files changed: +81 -34 lines

- vllm_ascend/ops/fused_moe.py
- vllm_ascend/quantization/w8a8_dynamic.py
- vllm_ascend/utils.py

vllm_ascend/ops/fused_moe.py

Lines changed: 73 additions & 32 deletions
@@ -1048,7 +1048,9 @@ def apply(
             expert_map=expert_map,
             moe_all_to_all_group_name=self.moe_all_to_all_group_name,
             shared_experts=shared_experts)
-        elif fused_moe_state == FusedMoEState.AllGather:
+        elif fused_moe_state in [
+                FusedMoEState.AllGather, FusedMoEState.NaiveMulticast
+        ]:
             return fused_experts(hidden_states=x,
                                  w1=layer.w13_weight,
                                  w2=layer.w2_weight,
@@ -1225,6 +1227,22 @@ def __init__(
         self.tp_group = get_tp_group().device_group
         self.quant_method.create_weights(layer=self, **moe_quant_params)
 
+    def naive_multicast(self, x: torch.Tensor,
+                        cu_tokens_across_dp_cpu: torch.Tensor):
+        assert (len(x.shape) == 2)
+        buffer = torch.empty((cu_tokens_across_dp_cpu[-1], x.size(1)),
+                             device=x.device,
+                             dtype=x.dtype)
+        start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[
+            self.dp_rank - 1]
+        end = cu_tokens_across_dp_cpu[self.dp_rank]
+        buffer[start:end, :].copy_(x)
+        for idx in range(self.dp_size):
+            start = 0 if idx == 0 else cu_tokens_across_dp_cpu[idx - 1]
+            end = cu_tokens_across_dp_cpu[idx]
+            get_dp_group().broadcast(buffer[start:end, :], idx)
+        return buffer
+
     def forward(self,
                 hidden_states: torch.Tensor,
                 router_logits: torch.Tensor,
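
The new naive_multicast helper above writes each rank's tokens into a shared, exactly-sized buffer and broadcasts each slice from its owning rank. Below is a minimal single-process sketch of the slicing arithmetic, with the collective simulated by local copies; the token counts and shapes are illustrative assumptions, not values from the patch:

```python
import torch

dp_size, hidden = 3, 4
# Per-rank token counts: each rank keeps its exact size, no padding to the max.
num_tokens = torch.tensor([2, 5, 1])
cu_tokens_across_dp_cpu = torch.cumsum(num_tokens, dim=0)  # tensor([2, 7, 8])

# Each rank's local activations, shaped by its exact token count.
local_x = [
    torch.full((int(n), hidden), float(r)) for r, n in enumerate(num_tokens)
]

# Every rank allocates the same global buffer and fills its own slice; the
# real helper then broadcasts each slice from its owning rank via
# get_dp_group().broadcast(buffer[start:end, :], idx).
buffer = torch.empty((int(cu_tokens_across_dp_cpu[-1]), hidden))
for rank in range(dp_size):
    start = 0 if rank == 0 else int(cu_tokens_across_dp_cpu[rank - 1])
    end = int(cu_tokens_across_dp_cpu[rank])
    buffer[start:end, :].copy_(local_x[rank])

# 2 + 5 + 1 = 8 rows total, versus 3 * 5 = 15 if every rank padded to the max.
assert buffer.shape == (8, hidden)
```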
@@ -1250,9 +1268,10 @@ def forward(self,
             shared_hidden_states = shared_experts(hidden_states)
 
         tp_size = get_tensor_model_parallel_world_size()
-        if (tp_size > 1 and fused_moe_state != FusedMoEState.AllGather
-                and fused_moe_state != FusedMoEState.AllGatherEP
-                and not replace_allreduce):
+        if (tp_size > 1 and fused_moe_state not in [
+                FusedMoEState.AllGather, FusedMoEState.AllGatherEP,
+                FusedMoEState.NaiveMulticast
+        ] and not replace_allreduce):
             if num_tokens < tp_size:
                 hidden_states = nn.functional.pad(
                     hidden_states, (0, 0, 0, tp_size - num_tokens))
@@ -1267,21 +1286,31 @@ def forward(self,
             tp_rank = get_tensor_model_parallel_rank()
             hidden_states = chunk_hidden_states[tp_rank]
             router_logits = chunk_router_logits[tp_rank]
-        if self.dp_size > 1 and fused_moe_state == FusedMoEState.AllGather:
-            # NOTE: When in torchair graph, it has been padded in model_runner_v1
-            if not self.torchair_graph_enabled or is_prefill:
-                attn_metadata = get_forward_context().attn_metadata
-                if attn_metadata is not None:
-                    max_num_tokens_across_dp = attn_metadata.max_num_tokens_across_dp
-                    if num_tokens < max_num_tokens_across_dp:
-                        hidden_states = nn.functional.pad(
-                            hidden_states,
-                            (0, 0, 0, max_num_tokens_across_dp - num_tokens))
-                        router_logits = nn.functional.pad(
-                            router_logits,
-                            (0, 0, 0, max_num_tokens_across_dp - num_tokens))
-            hidden_states = get_dp_group().all_gather(hidden_states, 0)
-            router_logits = get_dp_group().all_gather(router_logits, 0)
+        if self.dp_size > 1:
+            if fused_moe_state == FusedMoEState.AllGather:
+                # NOTE: When in torchair graph, it has been padded in model_runner_v1
+                if not self.torchair_graph_enabled:
+                    attn_metadata = get_forward_context().attn_metadata
+                    if attn_metadata is not None:
+                        max_num_tokens_across_dp = attn_metadata.max_num_tokens_across_dp
+                        if num_tokens < max_num_tokens_across_dp:
+                            hidden_states = nn.functional.pad(
+                                hidden_states,
+                                (0, 0, 0,
+                                 max_num_tokens_across_dp - num_tokens))
+                            router_logits = nn.functional.pad(
+                                router_logits,
+                                (0, 0, 0,
+                                 max_num_tokens_across_dp - num_tokens))
+                hidden_states = get_dp_group().all_gather(hidden_states, 0)
+                router_logits = get_dp_group().all_gather(router_logits, 0)
+            elif fused_moe_state == FusedMoEState.NaiveMulticast:
+                cu_tokens_across_dp_cpu = get_forward_context(
+                ).dp_metadata.cu_tokens_across_dp_cpu
+                hidden_states = self.naive_multicast(hidden_states,
+                                                     cu_tokens_across_dp_cpu)
+                router_logits = self.naive_multicast(router_logits,
+                                                     cu_tokens_across_dp_cpu)
 
         # Matrix multiply.
         e_hidden_states = self.quant_method.apply(
@@ -1310,28 +1339,40 @@ def forward(self,
         if isinstance(e_hidden_states, tuple):
             e_hidden_states, shared_hidden_states = e_hidden_states
 
-        if (tp_size > 1 and fused_moe_state != FusedMoEState.AllGather
-                and fused_moe_state != FusedMoEState.AllGatherEP
-                and not replace_allreduce):
+        if (tp_size > 1 and fused_moe_state not in [
+                FusedMoEState.AllGather, FusedMoEState.AllGatherEP,
+                FusedMoEState.NaiveMulticast
+        ] and not replace_allreduce):
             dist.all_gather(list(chunk_hidden_states), e_hidden_states,
                             self.tp_group)
             final_hidden_states = torch.cat(chunk_hidden_states, dim=0)
             if num_tokens < tp_size:
                 final_hidden_states = final_hidden_states[:num_tokens]
             dispose_tensor(e_hidden_states)
-        elif self.dp_size > 1 and fused_moe_state == FusedMoEState.AllGather:
-            final_hidden_states = dist._functional_collectives.reduce_scatter_tensor(
-                e_hidden_states,
-                "sum",
-                scatter_dim=0,
-                group=get_dp_group().device_group)
-            final_hidden_states = final_hidden_states[:num_tokens]
-            dispose_tensor(e_hidden_states)
+        elif self.dp_size > 1:
+            if fused_moe_state == FusedMoEState.NaiveMulticast:
+                start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[
+                    self.dp_rank - 1]
+                end = cu_tokens_across_dp_cpu[self.dp_rank]
+                final_hidden_states = get_dp_group().all_reduce(
+                    e_hidden_states)
+                final_hidden_states = final_hidden_states[start:end, :]
+                dispose_tensor(e_hidden_states)
+            elif fused_moe_state == FusedMoEState.AllGather:
+                final_hidden_states = dist._functional_collectives.reduce_scatter_tensor(
+                    e_hidden_states,
+                    "sum",
+                    scatter_dim=0,
+                    group=get_dp_group().device_group)
+                final_hidden_states = final_hidden_states[:num_tokens]
+                dispose_tensor(e_hidden_states)
         else:
             final_hidden_states = e_hidden_states
 
-        if tp_size > 1 and (fused_moe_state == FusedMoEState.AllGather
-                            or fused_moe_state == FusedMoEState.AllGatherEP):
+        if tp_size > 1 and fused_moe_state in [
+                FusedMoEState.AllGather, FusedMoEState.AllGatherEP,
+                FusedMoEState.NaiveMulticast
+        ]:
             final_hidden_states = tensor_model_parallel_all_reduce(
                 final_hidden_states)
 
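
On the combine side, the NaiveMulticast branch in the hunk above all-reduces the expert output across the DP group and then slices each rank's own token range back out of the concatenated batch. A single-process sketch of that recovery step, simulating the all-reduce with a sum over per-rank partial outputs (names and shapes are illustrative assumptions, continuing the sketch above):

```python
import torch

dp_size, hidden = 3, 4
cu = torch.tensor([2, 7, 8])  # cumulative token counts across DP ranks

# Simulate each DP rank's partial expert output over the full 8-row batch.
partials = [torch.randn(int(cu[-1]), hidden) for _ in range(dp_size)]

# get_dp_group().all_reduce(e_hidden_states) sums these across ranks.
reduced = torch.stack(partials).sum(dim=0)

# Each rank keeps only its own rows, e.g. rank 1 recovers rows 2:7.
rank = 1
start = 0 if rank == 0 else int(cu[rank - 1])
end = int(cu[rank])
final_hidden_states = reduced[start:end, :]
assert final_hidden_states.shape == (5, hidden)
```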

vllm_ascend/quantization/w8a8_dynamic.py

Lines changed: 3 additions & 1 deletion
@@ -780,7 +780,9 @@ def apply(
             log2phy=log2phy,
             global_redundant_expert_num=global_redundant_expert_num,
             shared_experts=shared_experts)
-        elif fused_moe_state == FusedMoEState.AllGather:
+        elif fused_moe_state in [
+                FusedMoEState.AllGather, FusedMoEState.NaiveMulticast
+        ]:
             return fused_experts(hidden_states=x,
                                  w1=layer.w13_weight,
                                  w1_scale=layer.w13_weight_scale,

vllm_ascend/utils.py

Lines changed: 5 additions & 1 deletion
@@ -419,6 +419,7 @@ class FusedMoEState(Enum):
     All2All = 1
     MC2 = 2
     AllGatherEP = 3
+    NaiveMulticast = 4
 
 
 # TODO(zzzzwwjj): add soc_version to choose branch
@@ -430,7 +431,10 @@ def get_fused_moe_state(ep_size: int, with_prefill: bool,
             and is_deepseek_v3_r1):
         return FusedMoEState.AllGatherEP
     elif ep_size == 1:
-        return FusedMoEState.AllGather
+        if with_prefill:
+            return FusedMoEState.NaiveMulticast
+        else:
+            return FusedMoEState.AllGather
     # NOTE: mc2 need ep_size >= 16 & all2all can't use in torchair graph.
     elif ep_size < 16 or with_prefill:
         return FusedMoEState.All2All
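
The net effect of the utils.py change, restated as a self-contained sketch. This is simplified: the AllGatherEP branch with its DeepSeek V3/R1 and environment-flag gating is omitted, and choose_state is a hypothetical name, not the repo's function:

```python
from enum import Enum

class FusedMoEState(Enum):
    AllGather = 0
    All2All = 1
    MC2 = 2
    AllGatherEP = 3
    NaiveMulticast = 4

def choose_state(ep_size: int, with_prefill: bool) -> FusedMoEState:
    """Simplified mirror of get_fused_moe_state's ep_size-based dispatch."""
    if ep_size == 1:
        # New in this commit: prefill steps use the naive multicast, so
        # decode requests are no longer padded up to the prefill size.
        return (FusedMoEState.NaiveMulticast
                if with_prefill else FusedMoEState.AllGather)
    # mc2 needs ep_size >= 16, and all2all can't be used in the torchair graph.
    if ep_size < 16 or with_prefill:
        return FusedMoEState.All2All
    return FusedMoEState.MC2

assert choose_state(1, True) is FusedMoEState.NaiveMulticast
assert choose_state(1, False) is FusedMoEState.AllGather
```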
