[perf]Support MOE Multi-stream in Deepseek #947


Merged · 10 commits · Jun 5, 2025
1 change: 1 addition & 0 deletions tests/singlecard/test_ascend_config.py
@@ -114,5 +114,6 @@ def test_ascend_config_load_error():
},
}
with VllmRunner("facebook/opt-125m",
enforce_eager=False,
additional_config=input_additional_config_fake_2):
pass
4 changes: 3 additions & 1 deletion vllm_ascend/ascend_config.py
@@ -53,6 +53,8 @@ def __init__(self, torchair_graph_config):
"graph_batch_sizes", [])
self.graph_batch_sizes_init = torchair_graph_config.get(
"graph_batch_sizes_init", False)
self.enable_multistream_shared_expert = torchair_graph_config.get(
"enable_multistream_shared_expert", False)

if not isinstance(self.graph_batch_sizes, list):
raise TypeError("graph_batch_sizes must be list[int]")
@@ -105,7 +107,7 @@ def check_ascend_config(vllm_config, enforce_eager):
ascend_config = get_ascend_config()

# Both for V0 and V1 Engine, torchair_graph cannot be enabled with eager mode.
if ascend_config.torchair_graph_config.enabled and not enforce_eager:
if ascend_config.torchair_graph_config.enabled and enforce_eager:
raise RuntimeError(
"Can't enable graph mode and eager mode at the same time. Please set `enforce_eager=False` if you attempt to enable NPU graph mode."
)
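
For anyone trying the new option from this PR: below is a minimal usage sketch of enable_multistream_shared_expert via additional_config, mirroring the VllmRunner pattern in the test above. The model name and the import path of VllmRunner are illustrative assumptions, not part of this patch.

# Hypothetical usage sketch (not part of the diff).
from tests.conftest import VllmRunner  # assumed location of the test helper

additional_config = {
    "torchair_graph_config": {
        "enabled": True,
        # New option added by this PR; it only takes effect on the decode
        # (non-prefill) path when graph mode is on.
        "enable_multistream_shared_expert": True,
    },
}

# Graph mode and eager mode conflict, so enforce_eager must stay False here.
with VllmRunner("deepseek-ai/DeepSeek-V2-Lite",  # illustrative model name
                enforce_eager=False,
                additional_config=additional_config):
    pass
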
21 changes: 19 additions & 2 deletions vllm_ascend/models/deepseek_v2.py
@@ -216,6 +216,8 @@ def __init__(

ascend_config = get_ascend_config()
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
self.enable_multistream_shared_expert = \
ascend_config.torchair_graph_config.enable_multistream_shared_expert

def forward(
self,
@@ -238,6 +240,8 @@ def forward(

num_tokens, hidden_size = hidden_states.shape

multistream = self.enable_multistream_shared_expert and not is_prefill

old_hidden_states = hidden_states.clone()

if self.tp_size > 1:
@@ -259,13 +263,25 @@ def forward(
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)

kwargs = {}
if multistream:
kwargs.update({
"shared_experts": self.shared_experts,
"shared_hidden_states": old_hidden_states
})

hidden_states = self.experts(
hidden_states=hidden_states,
router_logits=router_logits,
is_prefill=is_prefill,
top_k=CustomDeepseekV2MoE.top_k,
enable_force_load_balance=enable_force_load_balance,
) * self.routed_scaling_factor
**kwargs)

if multistream:
hidden_states, shared_output = hidden_states

hidden_states = hidden_states * self.routed_scaling_factor

if self.tp_size > 1:
if self.torchair_graph_enabled:
@@ -288,7 +304,8 @@ def forward(
hidden_states = hidden_states[:-num_padding_tokens]

if self.n_shared_experts is not None:
shared_output = self.shared_experts(old_hidden_states)
if not multistream:
shared_output = self.shared_experts(old_hidden_states)

if shared_output is not None:
hidden_states = hidden_states + shared_output
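
Net effect of this file: when the multi-stream path is active (graph mode, decode), the fused-experts call returns a (routed_hidden_states, shared_output) tuple and the separate self.shared_experts(...) call is skipped. A condensed sketch of that calling convention, with placeholder names rather than the real CustomDeepseekV2MoE internals:

def moe_forward(experts, shared_experts, hidden_states, router_logits,
                multistream: bool, routed_scaling_factor: float = 1.0):
    # Condensed sketch of the control flow above; names are placeholders.
    kwargs = {}
    if multistream:
        # Hand the shared expert to the fused kernel so it can run on a
        # secondary stream, overlapped with the routed experts.
        kwargs["shared_experts"] = shared_experts
        kwargs["shared_hidden_states"] = hidden_states

    out = experts(hidden_states=hidden_states,
                  router_logits=router_logits,
                  **kwargs)

    if multistream:
        routed, shared_output = out                    # tuple return in this mode
    else:
        routed = out
        shared_output = shared_experts(hidden_states)  # original serial path

    return routed * routed_scaling_factor + shared_output
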
47 changes: 28 additions & 19 deletions vllm_ascend/ops/fused_moe.py
@@ -39,19 +39,18 @@
USING_LCCL_COM: bool = envs_ascend.USING_LCCL_COM


def fused_experts_with_mc2(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
top_k: int,
expert_map: torch.Tensor = None,
moe_all_to_all_group_name: Optional[str] = None,
) -> torch.Tensor:
def fused_experts_with_mc2(hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
top_k: int,
expert_map: torch.Tensor = None,
moe_all_to_all_group_name: Optional[str] = None,
**kwargs) -> torch.Tensor:
global_bs = 0
moe_expert_num = len(expert_map)
kwargs = {
kwargs_mc2 = {
"x": hidden_states,
"expert_ids": topk_ids,
"expert_shard_type": 0,
@@ -81,9 +80,9 @@ def fused_experts_with_mc2(
"tp_world_size": tp_size,
"tp_rank_id": tp_rank,
}
kwargs.update(stage1_kwargs)
kwargs_mc2.update(stage1_kwargs)

output = torch_npu.npu_moe_distribute_dispatch(**kwargs)
output = torch_npu.npu_moe_distribute_dispatch(**kwargs_mc2)
# comm_stream.wait_stream(torch.npu.current_stream())
expand_x, dynamic_scale, expand_idx, expert_token_nums, ep_recv_counts = output[
0:5]
@@ -119,7 +118,7 @@ def fused_experts_with_mc2(
down_out_list = torch.cat(down_out_list, dim=0)

# moeCombine
kwargs = {
kwargs_mc2 = {
"expand_x": down_out_list,
"expert_ids": topk_ids,
"expand_idx": expand_idx,
@@ -141,9 +140,9 @@ def fused_experts_with_mc2(
"tp_world_size": tp_size,
"tp_rank_id": tp_rank,
}
kwargs.update(stage3_kwargs)
kwargs_mc2.update(stage3_kwargs)

hidden_states = torch_npu.npu_moe_distribute_combine(**kwargs)
hidden_states = torch_npu.npu_moe_distribute_combine(**kwargs_mc2)

return hidden_states

@@ -675,7 +674,8 @@ def apply(
topk_ids=topk_ids,
top_k=top_k,
expert_map=expert_map,
moe_all_to_all_group_name=self.moe_all_to_all_group_name)
moe_all_to_all_group_name=self.moe_all_to_all_group_name,
**kwargs)
elif self.torchair_graph_enabled or get_ep_group().world_size == 1:
return fused_experts(hidden_states=x,
w1=layer.w13_weight,
@@ -772,6 +772,8 @@ def __init__(

ascend_config = get_ascend_config()
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
self.enable_multistream_shared_expert = \
ascend_config.torchair_graph_config.enable_multistream_shared_expert

if self.scoring_func != "softmax" and not self.use_grouped_topk:
raise ValueError("Only softmax scoring function is supported for "
@@ -818,7 +820,8 @@ def forward(self,
router_logits: torch.Tensor,
is_prefill: bool,
enable_force_load_balance: bool = False,
top_k=None):
top_k=None,
**kwargs):
assert self.quant_method is not None

if top_k:
@@ -862,7 +865,11 @@ def forward(self,
scoring_func=self.scoring_func,
e_score_correction_bias=self.e_score_correction_bias,
is_prefill=is_prefill,
enable_force_load_balance=enable_force_load_balance)
enable_force_load_balance=enable_force_load_balance,
**kwargs)

if self.enable_multistream_shared_expert and not is_prefill:
hidden_states, shared_output = hidden_states

if self.dp_size > 1:
if VLLM_ENABLE_MC2 and not is_prefill:
@@ -886,4 +893,6 @@ def forward(self,
if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
hidden_states = tensor_model_parallel_all_reduce(hidden_states)

if self.enable_multistream_shared_expert and not is_prefill:
return hidden_states, shared_output
return hidden_states
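
Together with the model change above, the extra keyword arguments now flow through three layers untouched, and the rename of the local dict to kwargs_mc2 is what keeps the caller-supplied **kwargs (the shared-expert tensors) from being clobbered before they reach the kernels. A self-contained schematic of that plumbing; the functions below are placeholders with heavily simplified signatures, not the real APIs:

def expert_kernel(x, **kwargs):
    # Stands in for fused_experts_with_mc2. In the real code a separate
    # kwargs_mc2 dict holds the dispatch/combine arguments, so the caller's
    # **kwargs are never overwritten; a tuple comes back only when a shared
    # expert was handed in.
    shared = kwargs.get("shared_experts")
    return (x, x) if shared is not None else x


def quant_apply(x, **kwargs):
    # Stands in for the quant method's apply(): it just forwards the extras.
    return expert_kernel(x, **kwargs)


def fused_moe_forward(x, is_prefill, multistream, **kwargs):
    # Stands in for AscendFusedMoE.forward: the tuple is unpacked only on
    # the multi-stream decode path.
    out = quant_apply(x, **kwargs)
    if multistream and not is_prefill:
        hidden_states, shared_output = out
        return hidden_states, shared_output
    return out
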
2 changes: 1 addition & 1 deletion vllm_ascend/quantization/quant_config.py
@@ -329,7 +329,7 @@ def apply(
layer, x, router_logits, top_k, renormalize, use_grouped_topk,
global_num_experts, expert_map, topk_group, num_expert_group,
custom_routing_function, scoring_func, e_score_correction_bias,
is_prefill, enable_force_load_balance)
is_prefill, enable_force_load_balance, **kwargs)

def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
if hasattr(self.quant_method, "process_weights_after_loading"):
103 changes: 81 additions & 22 deletions vllm_ascend/quantization/w8a8_dynamic.py
@@ -20,7 +20,8 @@
import torch
import torch.distributed as dist
import torch_npu
from vllm.distributed import GroupCoordinator
import torchair as tng # type: ignore
from vllm.distributed import GroupCoordinator, tensor_model_parallel_all_reduce

import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_config import get_ascend_config
Expand All @@ -38,7 +39,8 @@ def apply_mlp(hidden_states: torch.Tensor,
w2_scale: torch.Tensor,
group_list: torch.Tensor,
dynamic_scale: torch.Tensor = None,
group_list_type: int = 1) -> torch.Tensor:
group_list_type: int = 1,
**kwargs) -> torch.Tensor:
"""
apply MLP: gate_up_proj -> swiglu -> down_proj

@@ -72,6 +74,23 @@
else:
pertoken_scale = dynamic_scale

shared_experts = kwargs.get('shared_experts', None)
if shared_experts:
shared_gate_up = kwargs.get('shared_gate_up', None)
shared_dynamic_scale = kwargs.get('shared_dynamic_scale', None)
with tng.scope.npu_stream_switch('cv'):
tng.scope.npu_wait_tensor(shared_gate_up, hidden_states)
shared_x, shared_dynamic_scale = torch_npu.npu_dequant_swiglu_quant(
x=shared_gate_up,
weight_scale=shared_experts.gate_up_proj.weight_scale_fp32,
activation_scale=shared_dynamic_scale,
bias=None,
quant_scale=None,
quant_offset=None,
group_index=None,
activate_left=True,
quant_mode=1)

# gmm1: gate_up_proj
hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
@@ -100,25 +119,39 @@
group_type=0,
group_list=group_list,
output_dtype=w2_scale.dtype)[0]

if shared_experts:
with tng.scope.npu_stream_switch('cv'):
tng.scope.npu_wait_tensor(shared_x, hidden_states)
shared_output = torch_npu.npu_quant_matmul(
shared_x,
shared_experts.down_proj.weight,
shared_experts.down_proj.weight_scale,
pertoken_scale=shared_dynamic_scale,
output_dtype=torch.bfloat16,
)
if shared_experts.down_proj.reduce_results and shared_experts.down_proj.tp_size > 1:
shared_output = tensor_model_parallel_all_reduce(shared_output)
if shared_experts:
return hidden_states, shared_output
return hidden_states


def fused_experts_with_mc2(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
top_k: int,
expert_map: torch.Tensor = None,
moe_all_to_all_group_name: str = "",
) -> torch.Tensor:
def fused_experts_with_mc2(hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
top_k: int,
expert_map: torch.Tensor = None,
moe_all_to_all_group_name: str = "",
**kwargs) -> torch.Tensor:
global_bs = 0
moe_expert_num = len(expert_map)
# hidden_states = hidden_states.bfloat16()
kwargs = {
kwargs_mc2 = {
"x": hidden_states,
"expert_ids": topk_ids,
"expert_shard_type": 0,
@@ -149,9 +182,27 @@ def fused_experts_with_mc2(
"tp_world_size": tp_size,
"tp_rank_id": tp_rank,
}
kwargs.update(stage1_kwargs)
kwargs_mc2.update(stage1_kwargs)

shared_experts = kwargs.get('shared_experts', None)
if shared_experts:
shared_hidden_states = kwargs.get('shared_hidden_states', None)
with tng.scope.npu_stream_switch('cv'):
tng.scope.npu_wait_tensor(shared_hidden_states, hidden_states)
shared_x, shared_dynamic_scale = torch_npu.npu_dynamic_quant(
shared_hidden_states)
shared_gate_up = torch_npu.npu_quant_matmul(
shared_x,
shared_experts.gate_up_proj.weight,
shared_experts.gate_up_proj.weight_scale,
output_dtype=torch.int32,
)
kwargs.update({
"shared_gate_up": shared_gate_up,
"shared_dynamic_scale": shared_dynamic_scale,
})

output = torch_npu.npu_moe_distribute_dispatch(**kwargs)
output = torch_npu.npu_moe_distribute_dispatch(**kwargs_mc2)
# comm_stream.wait_stream(torch.npu.current_stream())
expand_x, dynamic_scale, expand_idx, expert_token_nums, ep_recv_counts = output[
0:5]
@@ -166,10 +217,15 @@
w2,
w2_scale,
expert_token_nums,
dynamic_scale=dynamic_scale)
dynamic_scale=dynamic_scale,
**kwargs)

multi_stream = isinstance(down_out_list, tuple)
if multi_stream:
down_out_list, shared_output = down_out_list

# moeCombine
kwargs = {
kwargs_mc2 = {
"expand_x": down_out_list,
"expert_ids": topk_ids,
"expand_idx": expand_idx,
@@ -193,10 +249,12 @@
"tp_world_size": tp_size,
"tp_rank_id": tp_rank,
}
kwargs.update(stage3_kwargs)
kwargs_mc2.update(stage3_kwargs)

hidden_states = torch_npu.npu_moe_distribute_combine(**kwargs)
hidden_states = torch_npu.npu_moe_distribute_combine(**kwargs_mc2)

if multi_stream:
return hidden_states, shared_output
return hidden_states


@@ -634,7 +692,8 @@ def apply(
topk_ids=topk_ids,
top_k=top_k,
expert_map=expert_map,
moe_all_to_all_group_name=self.moe_all_to_all_group_name)
moe_all_to_all_group_name=self.moe_all_to_all_group_name,
**kwargs)
elif self.torchair_graph_enabled or self.ep_group.world_size == 1:
return fused_experts(hidden_states=x,
w1=layer.w13_weight,
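
The heart of the optimization sits in this file: the shared expert's quantized matmuls are issued on a secondary 'cv' stream through torchair's scope helpers and synchronized with npu_wait_tensor, while the routed experts keep the main stream busy with dispatch, grouped matmuls and combine. A stripped-down sketch of that overlap pattern; only the torchair scope calls are taken from the patch, the compute callables are placeholders:

import torchair as tng  # type: ignore


def run_with_shared_expert_overlap(routed_fn, shared_fn,
                                   routed_input, shared_input):
    # Sketch of the multi-stream pattern above; routed_fn / shared_fn are
    # placeholders for the real routed-expert and shared-expert kernels.
    with tng.scope.npu_stream_switch('cv'):
        # Make the side-stream work wait until routed_input has been
        # produced on the main stream, then launch the shared-expert
        # kernels there.
        tng.scope.npu_wait_tensor(shared_input, routed_input)
        shared_output = shared_fn(shared_input)

    # The routed-expert kernels are issued on the current (main) stream,
    # so during decode the two paths execute concurrently.
    routed_output = routed_fn(routed_input)
    return routed_output, shared_output
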