[perf] Support MOE Multi-stream in Deepseek #947

Merged: 10 commits merged on Jun 5, 2025
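Background: this PR overlaps the shared-expert MLP with the routed-expert path by running the shared-expert work on a secondary NPU stream through torchair's graph-mode scopes. The snippet below is a minimal sketch of that pattern only, not code from this diff; it assumes a torch_npu/torchair environment with graph mode enabled, and the function and argument names are illustrative.

import torch
import torchair as tng  # torchair graph-mode helpers, imported the same way in this PR

def moe_forward_sketch(routed_experts, shared_experts, hidden_states, router_logits):
    # Routed experts run on the default stream.
    routed_out = routed_experts(hidden_states, router_logits)

    # Shared experts run on a secondary stream tagged 'cv'.
    with tng.scope.npu_stream_switch('cv'):
        # Make the 'cv' stream wait until router_logits is produced, so the two
        # streams overlap compute without racing on their inputs.
        tng.scope.npu_wait_tensor(hidden_states, router_logits)
        shared_out = shared_experts(hidden_states)

    return routed_out + shared_out

In this PR the optimization is gated by an additional_config flag, e.g. additional_config={"enable_cv_parallel": True} when constructing the engine (the exact plumbing depends on the vLLM version), and it only takes effect on the decode (non-prefill) path.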
56 changes: 48 additions & 8 deletions vllm_ascend/models/deepseek_v2.py
@@ -30,6 +30,7 @@
import torch
import torch.distributed as dist
import torch_npu
import torchair as tng # type: ignore
import vllm.envs as envs
from torch import nn
from transformers import PretrainedConfig
@@ -179,6 +180,12 @@ def __init__(
else:
self.gate.e_score_correction_bias = None

self.enable_cv_parallel = False
additional_config = get_current_vllm_config().additional_config
Review comment (Collaborator): please use ascend_config instead now. Note that the doc should be updated at the same time.
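A hedged sketch of what the reviewer's suggestion might look like inside CustomDeepseekV2MoE.__init__. It assumes a get_ascend_config() helper on the vllm-ascend side and that the flag would be exposed there; neither is shown in this diff, so treat the names as placeholders:

# Hypothetical replacement for the additional_config lookup below.
from vllm_ascend.ascend_config import get_ascend_config  # assumed helper

ascend_config = get_ascend_config()
# Flag name taken from this PR; its location on ascend_config is an assumption.
self.enable_cv_parallel = getattr(ascend_config, "enable_cv_parallel", False)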

if additional_config:
self.enable_cv_parallel = additional_config.get(
"enable_cv_parallel", False)

self.experts = AscendFusedMoE(
num_experts=config.n_routed_experts,
top_k=config.num_experts_per_tok,
@@ -241,8 +248,13 @@ def forward(

num_tokens, hidden_size = hidden_states.shape

cv_parallel = self.enable_cv_parallel and not is_prefill

if self.n_shared_experts is not None:
shared_output = self.shared_experts(hidden_states)
if not cv_parallel:
shared_output = self.shared_experts(hidden_states)
else:
shared_hidden_states = hidden_states

if self.tp_size > 1:
if envs_ascend.VLLM_ENABLE_MC2 and not is_prefill:
@@ -263,13 +275,41 @@
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)

hidden_states = self.experts(
hidden_states=hidden_states,
router_logits=router_logits,
is_prefill=is_prefill,
top_k=CustomDeepseekV2MoE.top_k,
enable_force_load_balance=enable_force_load_balance,
) * self.routed_scaling_factor
if self.n_shared_experts is not None and cv_parallel:
with tng.scope.npu_stream_switch('cv'):
tng.scope.npu_wait_tensor(shared_hidden_states, router_logits)
dynamic_scale = None
if self.shared_experts.is_dynamic_quant:
x, dynamic_scale = torch_npu.npu_dynamic_quant(
shared_hidden_states)
gate_up = torch_npu.npu_quant_matmul(
x,
self.shared_experts.gate_up_proj.weight,
self.shared_experts.gate_up_proj.weight_scale,
output_dtype=torch.int32,
)
else:
gate_up, _ = self.shared_experts.gate_up_proj(shared_hidden_states)

if cv_parallel:
hidden_states, shared_output = self.experts(
hidden_states=hidden_states,
router_logits=router_logits,
is_prefill=is_prefill,
top_k=CustomDeepseekV2MoE.top_k,
enable_force_load_balance=enable_force_load_balance,
shared_experts=self.shared_experts,
shared_gate_up=gate_up,
shared_dynamic_scale=dynamic_scale)
hidden_states = hidden_states * self.routed_scaling_factor
else:
hidden_states = self.experts(
hidden_states=hidden_states,
router_logits=router_logits,
is_prefill=is_prefill,
top_k=CustomDeepseekV2MoE.top_k,
enable_force_load_balance=enable_force_load_balance,
) * self.routed_scaling_factor

if self.tp_size > 1:
if self.enable_graph_mode:
63 changes: 44 additions & 19 deletions vllm_ascend/ops/fused_moe.py
@@ -20,6 +20,7 @@
import torch
import torch.distributed as dist
import torch_npu
import torchair as tng # type: ignore
from vllm.config import get_current_vllm_config
from vllm.distributed import (GroupCoordinator,
get_tensor_model_parallel_world_size,
@@ -38,19 +39,18 @@
USING_LCCL_COM: bool = envs_ascend.USING_LCCL_COM


def fused_experts_with_mc2(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
top_k: int,
expert_map: torch.Tensor = None,
moe_all_to_all_group_name: Optional[str] = None,
) -> torch.Tensor:
def fused_experts_with_mc2(hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
top_k: int,
expert_map: torch.Tensor = None,
moe_all_to_all_group_name: Optional[str] = None,
**kwargs) -> torch.Tensor:
global_bs = 0
moe_expert_num = len(expert_map)
kwargs = {
kwargs_mc2 = {
"x": hidden_states,
"expert_ids": topk_ids,
"expert_shard_type": 0,
@@ -80,13 +80,20 @@ def fused_experts_with_mc2(
"tp_world_size": tp_size,
"tp_rank_id": tp_rank,
}
kwargs.update(stage1_kwargs)
kwargs_mc2.update(stage1_kwargs)

output = torch_npu.npu_moe_distribute_dispatch(**kwargs)
output = torch_npu.npu_moe_distribute_dispatch(**kwargs_mc2)
# comm_stream.wait_stream(torch.npu.current_stream())
expand_x, dynamic_scale, expand_idx, expert_token_nums, ep_recv_counts = output[
0:5]

shared_experts = kwargs.get('shared_experts', None)
if shared_experts:
shared_gate_up = kwargs.get('shared_gate_up', None)
with tng.scope.npu_stream_switch('cv'):
tng.scope.npu_wait_tensor(shared_gate_up, expand_x)
shared_x = shared_experts.act_fn(shared_gate_up)

w1 = w1.transpose(1, 2)
expert_token_nums = torch.cumsum(expert_token_nums,
dim=0,
@@ -115,10 +122,15 @@ def fused_experts_with_mc2(
group_list=group_list,
)

if shared_experts:
with tng.scope.npu_stream_switch('cv'):
tng.scope.npu_wait_tensor(shared_x, down_out_list)
shared_output = shared_experts.down_proj(shared_x)

down_out_list = torch.cat(down_out_list, dim=0)

# moeCombine
kwargs = {
kwargs_mc2 = {
"expand_x": down_out_list,
"expert_ids": topk_ids,
"expand_idx": expand_idx,
@@ -140,10 +152,12 @@
"tp_world_size": tp_size,
"tp_rank_id": tp_rank,
}
kwargs.update(stage3_kwargs)
kwargs_mc2.update(stage3_kwargs)

hidden_states = torch_npu.npu_moe_distribute_combine(**kwargs)
hidden_states = torch_npu.npu_moe_distribute_combine(**kwargs_mc2)

if shared_experts:
return hidden_states, shared_output
return hidden_states


@@ -677,7 +691,8 @@ def apply(
topk_ids=topk_ids,
top_k=top_k,
expert_map=expert_map,
moe_all_to_all_group_name=self.moe_all_to_all_group_name)
moe_all_to_all_group_name=self.moe_all_to_all_group_name,
**kwargs)
elif self.enable_graph_mode or get_ep_group().world_size == 1:
return fused_experts(hidden_states=x,
w1=layer.w13_weight,
@@ -819,17 +834,21 @@ def __init__(
self.quant_method.create_weights(layer=self, **moe_quant_params)

self.enable_graph_mode = False
self.enable_cv_parallel = False
additional_config = get_current_vllm_config().additional_config
if additional_config:
self.enable_graph_mode = additional_config.get(
"enable_graph_mode", False)
self.enable_cv_parallel = additional_config.get(
"enable_cv_parallel", False)

def forward(self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
is_prefill: bool,
enable_force_load_balance: bool = False,
top_k=None):
top_k=None,
**kwargs):
assert self.quant_method is not None

if top_k:
@@ -873,7 +892,11 @@ def forward(self,
scoring_func=self.scoring_func,
e_score_correction_bias=self.e_score_correction_bias,
is_prefill=is_prefill,
enable_force_load_balance=enable_force_load_balance)
enable_force_load_balance=enable_force_load_balance,
**kwargs)

if self.enable_cv_parallel and not is_prefill:
hidden_states, shared_output = hidden_states

if self.dp_size > 1:
if VLLM_ENABLE_MC2 and not is_prefill:
@@ -897,4 +920,6 @@
if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
hidden_states = tensor_model_parallel_all_reduce(hidden_states)

if self.enable_cv_parallel and not is_prefill:
return hidden_states, shared_output
return hidden_states
2 changes: 1 addition & 1 deletion vllm_ascend/quantization/quant_config.py
@@ -329,7 +329,7 @@ def apply(
layer, x, router_logits, top_k, renormalize, use_grouped_topk,
global_num_experts, expert_map, topk_group, num_expert_group,
custom_routing_function, scoring_func, e_score_correction_bias,
is_prefill, enable_force_load_balance)
is_prefill, enable_force_load_balance, **kwargs)

def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
if hasattr(self.quant_method, "process_weights_after_loading"):
85 changes: 63 additions & 22 deletions vllm_ascend/quantization/w8a8_dynamic.py
@@ -20,8 +20,9 @@
import torch
import torch.distributed as dist
import torch_npu
import torchair as tng # type: ignore
from vllm.config import get_current_vllm_config
from vllm.distributed import GroupCoordinator
from vllm.distributed import GroupCoordinator, tensor_model_parallel_all_reduce

import vllm_ascend.envs as envs_ascend
from vllm_ascend.distributed.parallel_state import get_ep_group
@@ -38,7 +39,8 @@ def apply_mlp(hidden_states: torch.Tensor,
w2_scale: torch.Tensor,
group_list: torch.Tensor,
dynamic_scale: torch.Tensor = None,
group_list_type: int = 1) -> torch.Tensor:
group_list_type: int = 1,
**kwargs) -> torch.Tensor:
"""
apply MLP: gate_up_proj -> swiglu -> down_proj

@@ -72,6 +74,23 @@
else:
pertoken_scale = dynamic_scale

shared_experts = kwargs.get('shared_experts', None)
if shared_experts:
shared_gate_up = kwargs.get('shared_gate_up', None)
shared_dynamic_scale = kwargs.get('shared_dynamic_scale', None)
with tng.scope.npu_stream_switch('cv'):
tng.scope.npu_wait_tensor(shared_gate_up, hidden_states)
shared_x, shared_dynamic_scale = torch_npu.npu_dequant_swiglu_quant(
x=shared_gate_up,
weight_scale=shared_experts.gate_up_proj.weight_scale_fp32,
activation_scale=shared_dynamic_scale,
bias=None,
quant_scale=None,
quant_offset=None,
group_index=None,
activate_left=True,
quant_mode=1)

# gmm1: gate_up_proj
hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
@@ -100,25 +119,39 @@ def apply_mlp(hidden_states: torch.Tensor,
group_type=0,
group_list=group_list,
output_dtype=w2_scale.dtype)[0]

if shared_experts:
with tng.scope.npu_stream_switch('cv'):
tng.scope.npu_wait_tensor(shared_x, hidden_states)
shared_output = torch_npu.npu_quant_matmul(
shared_x,
shared_experts.down_proj.weight,
shared_experts.down_proj.weight_scale,
pertoken_scale=shared_dynamic_scale,
output_dtype=torch.bfloat16,
)
if shared_experts.down_proj.reduce_results and shared_experts.down_proj.tp_size > 1:
shared_output = tensor_model_parallel_all_reduce(shared_output)
if shared_experts:
return hidden_states, shared_output
return hidden_states


def fused_experts_with_mc2(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
top_k: int,
expert_map: torch.Tensor = None,
moe_all_to_all_group_name: str = "",
) -> torch.Tensor:
def fused_experts_with_mc2(hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
top_k: int,
expert_map: torch.Tensor = None,
moe_all_to_all_group_name: str = "",
**kwargs) -> torch.Tensor:
global_bs = 0
moe_expert_num = len(expert_map)
# hidden_states = hidden_states.bfloat16()
kwargs = {
kwargs_mc2 = {
"x": hidden_states,
"expert_ids": topk_ids,
"expert_shard_type": 0,
@@ -149,9 +182,9 @@ def fused_experts_with_mc2(
"tp_world_size": tp_size,
"tp_rank_id": tp_rank,
}
kwargs.update(stage1_kwargs)
kwargs_mc2.update(stage1_kwargs)

output = torch_npu.npu_moe_distribute_dispatch(**kwargs)
output = torch_npu.npu_moe_distribute_dispatch(**kwargs_mc2)
# comm_stream.wait_stream(torch.npu.current_stream())
expand_x, dynamic_scale, expand_idx, expert_token_nums, ep_recv_counts = output[
0:5]
@@ -166,10 +199,15 @@
w2,
w2_scale,
expert_token_nums,
dynamic_scale=dynamic_scale)
dynamic_scale=dynamic_scale,
**kwargs)

multi_stream = isinstance(down_out_list, tuple)
if multi_stream:
down_out_list, shared_output = down_out_list

# moeCombine
kwargs = {
kwargs_mc2 = {
"expand_x": down_out_list,
"expert_ids": topk_ids,
"expand_idx": expand_idx,
Expand All @@ -193,10 +231,12 @@ def fused_experts_with_mc2(
"tp_world_size": tp_size,
"tp_rank_id": tp_rank,
}
kwargs.update(stage3_kwargs)
kwargs_mc2.update(stage3_kwargs)

hidden_states = torch_npu.npu_moe_distribute_combine(**kwargs)
hidden_states = torch_npu.npu_moe_distribute_combine(**kwargs_mc2)

if multi_stream:
return hidden_states, shared_output
return hidden_states


@@ -637,7 +677,8 @@ def apply(
topk_ids=topk_ids,
top_k=top_k,
expert_map=expert_map,
moe_all_to_all_group_name=self.moe_all_to_all_group_name)
moe_all_to_all_group_name=self.moe_all_to_all_group_name,
**kwargs)
elif self.enable_graph_mode or self.ep_group.world_size == 1:
return fused_experts(hidden_states=x,
w1=layer.w13_weight,