From 10744135b7ce5b5181fb33441d0a2a891132c44d Mon Sep 17 00:00:00 2001
From: David9857 <985700846@qq.com>
Date: Thu, 29 May 2025 11:38:41 +0800
Subject: [PATCH 1/6] support moe multistream in deepseek

Signed-off-by: David9857 <985700846@qq.com>

use additional_config to enable cv parallel

Signed-off-by: David9857 <985700846@qq.com>

rename kwargs1 in fused_experts_with_mc2

Signed-off-by: David9857 <985700846@qq.com>
---
 vllm_ascend/models/deepseek_v2.py        | 52 ++++++++++++---
 vllm_ascend/ops/fused_moe.py             | 16 ++++-
 vllm_ascend/quantization/quant_config.py |  2 +-
 vllm_ascend/quantization/w8a8_dynamic.py | 85 ++++++++++++++++++------
 4 files changed, 122 insertions(+), 33 deletions(-)

diff --git a/vllm_ascend/models/deepseek_v2.py b/vllm_ascend/models/deepseek_v2.py
index 5e97444157..5649625e6e 100644
--- a/vllm_ascend/models/deepseek_v2.py
+++ b/vllm_ascend/models/deepseek_v2.py
@@ -30,6 +30,7 @@
 import torch
 import torch.distributed as dist
 import torch_npu
+import torchair as tng  # type: ignore
 import vllm.envs as envs
 from torch import nn
 from transformers import PretrainedConfig
@@ -177,6 +178,12 @@ def __init__(
         else:
             self.gate.e_score_correction_bias = None
 
+        self.enable_cv_parallel = False
+        additional_config = get_current_vllm_config().additional_config
+        if additional_config:
+            self.enable_cv_parallel = additional_config.get(
+                "enable_cv_parallel", False)
+
         self.experts = AscendFusedMoE(
             num_experts=config.n_routed_experts,
             top_k=config.num_experts_per_tok,
@@ -224,8 +231,13 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
             enable_force_load_balance = False
         num_tokens, hidden_dim = hidden_states.shape
 
+        cv_parallel = self.enable_cv_parallel and not is_prefill
+
         if self.n_shared_experts is not None:
-            shared_output = self.shared_experts(hidden_states)
+            if not cv_parallel:
+                shared_output = self.shared_experts(hidden_states)
+            else:
+                shared_hidden_states = hidden_states
 
         if self.tp_size > 1:
             # pass
@@ -247,13 +259,37 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(local_hidden_states)
 
-        router_hidden_states = self.experts(
-            hidden_states=local_hidden_states,
-            router_logits=router_logits,
-            is_prefill=is_prefill,
-            top_k=CustomDeepseekV2MoE.top_k,
-            enable_force_load_balance=enable_force_load_balance,
-        ) * self.routed_scaling_factor
+        if self.n_shared_experts is not None and cv_parallel:
+            with tng.scope.npu_stream_switch('cv'):
+                tng.scope.npu_wait_tensor(shared_hidden_states, router_logits)
+                x, dynamic_scale = torch_npu.npu_dynamic_quant(
+                    shared_hidden_states)
+                gate_up = torch_npu.npu_quant_matmul(
+                    x,
+                    self.shared_experts.gate_up_proj.weight,
+                    self.shared_experts.gate_up_proj.weight_scale,
+                    output_dtype=torch.int32,
+                )
+
+        if cv_parallel:
+            router_hidden_states, shared_output = self.experts(
+                hidden_states=local_hidden_states,
+                router_logits=router_logits,
+                is_prefill=is_prefill,
+                top_k=CustomDeepseekV2MoE.top_k,
+                enable_force_load_balance=enable_force_load_balance,
+                shared_experts=self.shared_experts,
+                shared_gate_up=gate_up,
+                shared_dynamic_scale=dynamic_scale)
+            router_hidden_states = router_hidden_states * self.routed_scaling_factor
+        else:
+            router_hidden_states = self.experts(
+                hidden_states=local_hidden_states,
+                router_logits=router_logits,
+                is_prefill=is_prefill,
+                top_k=CustomDeepseekV2MoE.top_k,
+                enable_force_load_balance=enable_force_load_balance,
+            ) * self.routed_scaling_factor
         if self.tp_size > 1:
             dist.all_gather(list(chunk_hidden_states),
                             router_hidden_states,
diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py
index 74a292d576..2b7563166d 100644
--- a/vllm_ascend/ops/fused_moe.py
+++ b/vllm_ascend/ops/fused_moe.py
@@ -810,12 +810,18 @@ def __init__(
 
         self.quant_method.create_weights(layer=self, **moe_quant_params)
 
+        self.enable_cv_parallel = False
+        if vllm_config.additional_config:
+            self.enable_cv_parallel = vllm_config.additional_config.get(
+                "enable_cv_parallel", False)
+
     def forward(self,
                 hidden_states: torch.Tensor,
                 router_logits: torch.Tensor,
                 is_prefill: bool,
                 enable_force_load_balance: bool = False,
-                top_k=None):
+                top_k=None,
+                **kwargs):
         assert self.quant_method is not None
 
         if top_k:
@@ -842,7 +848,11 @@ def forward(self,
             scoring_func=self.scoring_func,
             e_score_correction_bias=self.e_score_correction_bias,
             is_prefill=is_prefill,
-            enable_force_load_balance=enable_force_load_balance)
+            enable_force_load_balance=enable_force_load_balance,
+            **kwargs)
+
+        if self.enable_cv_parallel and not is_prefill:
+            final_hidden_states, shared_output = final_hidden_states
 
         if VLLM_ENABLE_MC2 and not is_prefill:
             ...
@@ -851,4 +861,6 @@ def forward(self,
             final_hidden_states = tensor_model_parallel_all_reduce(
                 final_hidden_states)
 
+        if self.enable_cv_parallel and not is_prefill:
+            return final_hidden_states, shared_output
         return final_hidden_states
diff --git a/vllm_ascend/quantization/quant_config.py b/vllm_ascend/quantization/quant_config.py
index 40dbae38f8..e43f25d5bb 100644
--- a/vllm_ascend/quantization/quant_config.py
+++ b/vllm_ascend/quantization/quant_config.py
@@ -329,7 +329,7 @@ def apply(
             layer, x, router_logits, top_k, renormalize, use_grouped_topk,
             global_num_experts, expert_map, topk_group, num_expert_group,
             custom_routing_function, scoring_func, e_score_correction_bias,
-            is_prefill, enable_force_load_balance)
+            is_prefill, enable_force_load_balance, **kwargs)
 
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         if hasattr(self.quant_method, "process_weights_after_loading"):
diff --git a/vllm_ascend/quantization/w8a8_dynamic.py b/vllm_ascend/quantization/w8a8_dynamic.py
index 0f54b012f1..37f39e7e69 100644
--- a/vllm_ascend/quantization/w8a8_dynamic.py
+++ b/vllm_ascend/quantization/w8a8_dynamic.py
@@ -20,7 +20,8 @@
 import torch
 import torch.distributed as dist
 import torch_npu
-from vllm.distributed import GroupCoordinator
+import torchair as tng  # type: ignore
+from vllm.distributed import GroupCoordinator, tensor_model_parallel_all_reduce
 
 import vllm_ascend.envs as envs_ascend
 from vllm_ascend.distributed.parallel_state import get_ep_group
@@ -36,7 +37,8 @@ def apply_mlp(hidden_states_wrapper: List[torch.Tensor],
               w2_scale: torch.Tensor,
               group_list: torch.Tensor,
               dynamic_scale: torch.Tensor = None,
-              group_list_type: int = 1) -> torch.Tensor:
+              group_list_type: int = 1,
+              **kwargs) -> torch.Tensor:
     """
     apply MLP: gate_up_proj -> swiglu -> down_proj
 
@@ -68,6 +70,23 @@ def apply_mlp(hidden_states_wrapper: List[torch.Tensor],
     else:
         pertoken_scale = dynamic_scale
 
+    shared_experts = kwargs.get('shared_experts', None)
+    if shared_experts:
+        shared_gate_up = kwargs.get('shared_gate_up', None)
+        shared_dynamic_scale = kwargs.get('shared_dynamic_scale', None)
+        with tng.scope.npu_stream_switch('1'):
+            tng.scope.npu_wait_tensor(shared_gate_up, hidden_states)
+            shared_x, shared_dynamic_scale = torch_npu.npu_dequant_swiglu_quant(
+                x=shared_gate_up,
+                weight_scale=shared_experts.gate_up_proj.weight_scale_fp32,
+                activation_scale=shared_dynamic_scale,
+                bias=None,
+                quant_scale=None,
+                quant_offset=None,
+                group_index=None,
+                activate_left=True,
+                quant_mode=1)
+
     # gmm1: gate_up_proj
     hidden_states = torch_npu.npu_grouped_matmul(
         x=[hidden_states],
@@ -96,25 +115,39 @@ def apply_mlp(hidden_states_wrapper: List[torch.Tensor],
         group_type=0,
         group_list=group_list,
         output_dtype=w2_scale.dtype)[0]
+
+    if shared_experts:
+        with tng.scope.npu_stream_switch('1'):
+            tng.scope.npu_wait_tensor(shared_x, hidden_states)
+            shared_output = torch_npu.npu_quant_matmul(
+                shared_x,
+                shared_experts.down_proj.weight,
+                shared_experts.down_proj.weight_scale,
+                pertoken_scale=shared_dynamic_scale,
+                output_dtype=torch.bfloat16,
+            )
+            if shared_experts.down_proj.reduce_results and shared_experts.down_proj.tp_size > 1:
+                shared_output = tensor_model_parallel_all_reduce(shared_output)
+    if shared_experts:
+        return hidden_states, shared_output
     return hidden_states
 
 
-def fused_experts_with_mc2(
-    hidden_states: torch.Tensor,
-    w1: torch.Tensor,
-    w2: torch.Tensor,
-    w1_scale: torch.Tensor,
-    w2_scale: torch.Tensor,
-    topk_weights: torch.Tensor,
-    topk_ids: torch.Tensor,
-    top_k: int,
-    expert_map: torch.Tensor = None,
-    moe_all_to_all_group_name: str = "",
-) -> torch.Tensor:
+def fused_experts_with_mc2(hidden_states: torch.Tensor,
+                           w1: torch.Tensor,
+                           w2: torch.Tensor,
+                           w1_scale: torch.Tensor,
+                           w2_scale: torch.Tensor,
+                           topk_weights: torch.Tensor,
+                           topk_ids: torch.Tensor,
+                           top_k: int,
+                           expert_map: torch.Tensor = None,
+                           moe_all_to_all_group_name: str = "",
+                           **kwargs) -> torch.Tensor:
     global_bs = 0
     moe_expert_num = len(expert_map)
     # hidden_states = hidden_states.bfloat16()
-    kwargs = {
+    kwargs_mc2 = {
         "x": hidden_states,
         "expert_ids": topk_ids,
         "expert_shard_type": 0,
@@ -145,9 +178,9 @@ def fused_experts_with_mc2(
         "tp_world_size": tp_size,
         "tp_rank_id": tp_rank,
     }
-    kwargs.update(stage1_kwargs)
+    kwargs_mc2.update(stage1_kwargs)
 
-    output = torch_npu.npu_moe_distribute_dispatch(**kwargs)
+    output = torch_npu.npu_moe_distribute_dispatch(**kwargs_mc2)
     # comm_stream.wait_stream(torch.npu.current_stream())
     expand_x, dynamic_scale, expand_idx, expert_token_nums, ep_recv_counts = output[
         0:5]
@@ -165,10 +198,15 @@ def fused_experts_with_mc2(
                               w2,
                               w2_scale,
                               expert_token_nums,
-                              dynamic_scale=dynamic_scale)
+                              dynamic_scale=dynamic_scale,
+                              **kwargs)
+
+    multi_stream = isinstance(down_out_list, tuple)
+    if multi_stream:
+        down_out_list, shared_output = down_out_list
 
     # moeCombine
-    kwargs = {
+    kwargs_mc2 = {
         "expand_x": down_out_list,
         "expert_ids": topk_ids,
         "expand_idx": expand_idx,
@@ -192,10 +230,12 @@ def fused_experts_with_mc2(
         "tp_world_size": tp_size,
         "tp_rank_id": tp_rank,
     }
-    kwargs.update(stage3_kwargs)
+    kwargs_mc2.update(stage3_kwargs)
 
-    hidden_states = torch_npu.npu_moe_distribute_combine(**kwargs)
+    hidden_states = torch_npu.npu_moe_distribute_combine(**kwargs_mc2)
 
+    if multi_stream:
+        return hidden_states, shared_output
     return hidden_states
 
 
@@ -633,7 +673,8 @@ def apply(
             topk_ids=topk_ids,
             top_k=top_k,
             expert_map=expert_map,
-            moe_all_to_all_group_name=self.moe_all_to_all_group_name)
+            moe_all_to_all_group_name=self.moe_all_to_all_group_name,
+            **kwargs)
         elif self.ep_group.world_size == 1:
             return fused_experts(hidden_states=x,
                                  w1=layer.w13_weight,

From 3630856497bdc53a5e1f972d8364ded6c98a86f9 Mon Sep 17 00:00:00 2001
From: David9857 <985700846@qq.com>
Date: Thu, 29 May 2025 22:54:12 +0800
Subject: [PATCH 2/6] support cv parallel for float model

Signed-off-by: David9857 <985700846@qq.com>
---
 vllm_ascend/models/deepseek_v2.py        | 20 ++++++----
 vllm_ascend/ops/fused_moe.py             | 49 ++++++++++++++++++++--------
 vllm_ascend/quantization/w8a8_dynamic.py |  4 +-
 3 files changed, 46 insertions(+), 27 deletions(-)

diff --git a/vllm_ascend/models/deepseek_v2.py b/vllm_ascend/models/deepseek_v2.py
index 5649625e6e..615d1ac660 100644
--- a/vllm_ascend/models/deepseek_v2.py
+++ b/vllm_ascend/models/deepseek_v2.py
@@ -262,14 +262,18 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         if self.n_shared_experts is not None and cv_parallel:
             with tng.scope.npu_stream_switch('cv'):
                 tng.scope.npu_wait_tensor(shared_hidden_states, router_logits)
-                x, dynamic_scale = torch_npu.npu_dynamic_quant(
-                    shared_hidden_states)
-                gate_up = torch_npu.npu_quant_matmul(
-                    x,
-                    self.shared_experts.gate_up_proj.weight,
-                    self.shared_experts.gate_up_proj.weight_scale,
-                    output_dtype=torch.int32,
-                )
+                dynamic_scale = None
+                if self.shared_experts.is_dynamic_quant:
+                    x, dynamic_scale = torch_npu.npu_dynamic_quant(
+                        shared_hidden_states)
+                    gate_up = torch_npu.npu_quant_matmul(
+                        x,
+                        self.shared_experts.gate_up_proj.weight,
+                        self.shared_experts.gate_up_proj.weight_scale,
+                        output_dtype=torch.int32,
+                    )
+                else:
+                    gate_up, _ = self.gate_up_proj(shared_hidden_states)
 
         if cv_parallel:
             router_hidden_states, shared_output = self.experts(
diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py
index 2b7563166d..420a2a76e0 100644
--- a/vllm_ascend/ops/fused_moe.py
+++ b/vllm_ascend/ops/fused_moe.py
@@ -20,6 +20,7 @@
 import torch
 import torch.distributed as dist
 import torch_npu
+import torchair as tng  # type: ignore
 from vllm.config import get_current_vllm_config
 from vllm.distributed import (GroupCoordinator,
                               get_tensor_model_parallel_world_size,
@@ -38,19 +39,18 @@
 USING_LCCL_COM: bool = envs_ascend.USING_LCCL_COM
 
 
-def fused_experts_with_mc2(
-    hidden_states: torch.Tensor,
-    w1: torch.Tensor,
-    w2: torch.Tensor,
-    topk_weights: torch.Tensor,
-    topk_ids: torch.Tensor,
-    top_k: int,
-    expert_map: torch.Tensor = None,
-    moe_all_to_all_group_name: Optional[str] = None,
-) -> torch.Tensor:
+def fused_experts_with_mc2(hidden_states: torch.Tensor,
+                           w1: torch.Tensor,
+                           w2: torch.Tensor,
+                           topk_weights: torch.Tensor,
+                           topk_ids: torch.Tensor,
+                           top_k: int,
+                           expert_map: torch.Tensor = None,
+                           moe_all_to_all_group_name: Optional[str] = None,
+                           **kwargs) -> torch.Tensor:
     global_bs = 0
     moe_expert_num = len(expert_map)
-    kwargs = {
+    kwargs_mc2 = {
         "x": hidden_states,
         "expert_ids": topk_ids,
         "expert_shard_type": 0,
@@ -81,13 +81,20 @@ def fused_experts_with_mc2(
         "tp_world_size": tp_size,
         "tp_rank_id": tp_rank,
     }
-    kwargs.update(stage1_kwargs)
+    kwargs_mc2.update(stage1_kwargs)
 
-    output = torch_npu.npu_moe_distribute_dispatch(**kwargs)
+    output = torch_npu.npu_moe_distribute_dispatch(**kwargs_mc2)
     # comm_stream.wait_stream(torch.npu.current_stream())
     expand_x, dynamic_scale, expand_idx, expert_token_nums, ep_recv_counts = output[
         0:5]
 
+    shared_experts = kwargs.get('shared_experts', None)
+    if shared_experts:
+        shared_gate_up = kwargs.get('shared_gate_up', None)
+        with tng.scope.npu_stream_switch('cv'):
+            tng.scope.npu_wait_tensor(shared_gate_up, expand_x)
+            shared_x = shared_experts.act_fn(shared_gate_up)
+
     w1 = w1.transpose(1, 2)
     expert_token_nums = torch.cumsum(expert_token_nums,
                                      dim=0,
@@ -116,10 +123,15 @@ def fused_experts_with_mc2(
         group_list=group_list,
     )
 
+    if shared_experts:
+        with tng.scope.npu_stream_switch('cv'):
+            tng.scope.npu_wait_tensor(shared_x, down_out_list)
+            shared_output = shared_experts.down_proj(shared_x)
+
     down_out_list = torch.cat(down_out_list, dim=0)
 
     # moeCombine
-    kwargs = {
+    kwargs_mc2 = {
         "expand_x": down_out_list,
         "expert_ids": topk_ids,
         "expand_idx": expand_idx,
@@ -141,10 +153,12 @@ def fused_experts_with_mc2(
         "tp_world_size": tp_size,
         "tp_rank_id": tp_rank,
     }
-    kwargs.update(stage3_kwargs)
+    kwargs_mc2.update(stage3_kwargs)
 
-    hidden_states = torch_npu.npu_moe_distribute_combine(**kwargs)
+    hidden_states = torch_npu.npu_moe_distribute_combine(**kwargs_mc2)
 
+    if shared_experts:
+        return hidden_states, shared_output
     return hidden_states
 
 
@@ -664,7 +678,8 @@ def apply(
             topk_ids=topk_ids,
             top_k=top_k,
             expert_map=expert_map,
-            moe_all_to_all_group_name=self.moe_all_to_all_group_name)
+            moe_all_to_all_group_name=self.moe_all_to_all_group_name,
+            **kwargs)
         elif get_ep_group().world_size == 1:
             return fused_experts(hidden_states=x,
                                  w1=layer.w13_weight,
diff --git a/vllm_ascend/quantization/w8a8_dynamic.py b/vllm_ascend/quantization/w8a8_dynamic.py
index 37f39e7e69..e170f5dc45 100644
--- a/vllm_ascend/quantization/w8a8_dynamic.py
+++ b/vllm_ascend/quantization/w8a8_dynamic.py
@@ -74,7 +74,7 @@ def apply_mlp(hidden_states_wrapper: List[torch.Tensor],
     if shared_experts:
         shared_gate_up = kwargs.get('shared_gate_up', None)
         shared_dynamic_scale = kwargs.get('shared_dynamic_scale', None)
-        with tng.scope.npu_stream_switch('1'):
+        with tng.scope.npu_stream_switch('cv'):
             tng.scope.npu_wait_tensor(shared_gate_up, hidden_states)
             shared_x, shared_dynamic_scale = torch_npu.npu_dequant_swiglu_quant(
                 x=shared_gate_up,
@@ -117,7 +117,7 @@ def apply_mlp(hidden_states_wrapper: List[torch.Tensor],
                 output_dtype=w2_scale.dtype)[0]
 
     if shared_experts:
-        with tng.scope.npu_stream_switch('1'):
+        with tng.scope.npu_stream_switch('cv'):
             tng.scope.npu_wait_tensor(shared_x, hidden_states)
             shared_output = torch_npu.npu_quant_matmul(
                 shared_x,

From 3511331c9f5deb19cfa8e69829e5aa3acf44f946 Mon Sep 17 00:00:00 2001
From: David9857 <985700846@qq.com>
Date: Thu, 5 Jun 2025 20:32:48 +0800
Subject: [PATCH 3/6] refactor in deepseek moe

Signed-off-by: David9857 <985700846@qq.com>
---
 vllm_ascend/models/deepseek_v2.py        | 67 +++++++++----------
 vllm_ascend/ops/fused_moe.py             | 10 ++--
 vllm_ascend/quantization/w8a8_dynamic.py | 18 +++++++
 3 files changed, 47 insertions(+), 48 deletions(-)

diff --git a/vllm_ascend/models/deepseek_v2.py b/vllm_ascend/models/deepseek_v2.py
index 8f49c1b402..7e573e24c5 100644
--- a/vllm_ascend/models/deepseek_v2.py
+++ b/vllm_ascend/models/deepseek_v2.py
@@ -180,12 +180,6 @@ def __init__(
         else:
             self.gate.e_score_correction_bias = None
 
-        self.enable_cv_parallel = False
-        additional_config = get_current_vllm_config().additional_config
-        if additional_config:
-            self.enable_cv_parallel = additional_config.get(
-                "enable_cv_parallel", False)
-
         self.experts = AscendFusedMoE(
             num_experts=config.n_routed_experts,
             top_k=config.num_experts_per_tok,
@@ -222,10 +216,13 @@ def __init__(
         self.params_dtype = torch.get_default_dtype()
 
         self.enable_graph_mode = False
+        self.enable_multistream_shared_expert = False
         additional_config = get_current_vllm_config().additional_config
         if additional_config:
             self.enable_graph_mode = additional_config.get(
                 "enable_graph_mode", False)
+            self.enable_multistream_shared_expert = additional_config.get(
+                "enable_multistream_shared_expert", False)
 
     def forward(
         self,
@@ -248,10 +245,10 @@ def forward(
 
         num_tokens, hidden_size = hidden_states.shape
 
-        cv_parallel = self.enable_cv_parallel and not is_prefill
+        multistream = self.enable_multistream_shared_expert and not is_prefill
 
         if self.n_shared_experts is not None:
-            if not cv_parallel:
+            if not multistream:
                 shared_output = self.shared_experts(hidden_states)
             else:
                 shared_hidden_states = hidden_states
@@ -275,41 +272,25 @@ def forward(
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
 
-        if self.n_shared_experts is not None and cv_parallel:
-            with tng.scope.npu_stream_switch('cv'):
-                tng.scope.npu_wait_tensor(shared_hidden_states, router_logits)
-                dynamic_scale = None
-                if self.shared_experts.is_dynamic_quant:
-                    x, dynamic_scale = torch_npu.npu_dynamic_quant(
-                        shared_hidden_states)
-                    gate_up = torch_npu.npu_quant_matmul(
-                        x,
-                        self.shared_experts.gate_up_proj.weight,
-                        self.shared_experts.gate_up_proj.weight_scale,
-                        output_dtype=torch.int32,
-                    )
-                else:
-                    gate_up, _ = self.gate_up_proj(shared_hidden_states)
-
-        if cv_parallel:
-            hidden_states, shared_output = self.experts(
-                hidden_states=hidden_states,
-                router_logits=router_logits,
-                is_prefill=is_prefill,
-                top_k=CustomDeepseekV2MoE.top_k,
-                enable_force_load_balance=enable_force_load_balance,
-                shared_experts=self.shared_experts,
-                shared_gate_up=gate_up,
-                shared_dynamic_scale=dynamic_scale)
-            hidden_states = hidden_states * self.routed_scaling_factor
-        else:
-            hidden_states = self.experts(
-                hidden_states=hidden_states,
-                router_logits=router_logits,
-                is_prefill=is_prefill,
-                top_k=CustomDeepseekV2MoE.top_k,
-                enable_force_load_balance=enable_force_load_balance,
-            ) * self.routed_scaling_factor
+        kwargs = {}
+        if multistream:
+            kwargs.update({
+                "shared_experts": self.shared_experts,
+                "shared_hidden_states": shared_hidden_states
+            })
+
+        hidden_states = self.experts(
+            hidden_states=hidden_states,
+            router_logits=router_logits,
+            is_prefill=is_prefill,
+            top_k=CustomDeepseekV2MoE.top_k,
+            enable_force_load_balance=enable_force_load_balance,
+            **kwargs)
+
+        if multistream:
+            hidden_states, shared_output = hidden_states
+
+        hidden_states = hidden_states * self.routed_scaling_factor
 
         if self.tp_size > 1:
             if self.enable_graph_mode:
diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py
index 70db7590a0..e444758b41 100644
--- a/vllm_ascend/ops/fused_moe.py
+++ b/vllm_ascend/ops/fused_moe.py
@@ -834,13 +834,13 @@ def __init__(
         self.quant_method.create_weights(layer=self, **moe_quant_params)
 
         self.enable_graph_mode = False
-        self.enable_cv_parallel = False
+        self.enable_multistream_shared_expert = False
         additional_config = get_current_vllm_config().additional_config
         if additional_config:
             self.enable_graph_mode = additional_config.get(
                 "enable_graph_mode", False)
-            self.enable_cv_parallel = additional_config.get(
-                "enable_cv_parallel", False)
+            self.enable_multistream_shared_expert = additional_config.get(
+                "enable_multistream_shared_expert", False)
 
     def forward(self,
                 hidden_states: torch.Tensor,
@@ -895,7 +895,7 @@ def forward(self,
             enable_force_load_balance=enable_force_load_balance,
             **kwargs)
 
-        if self.enable_cv_parallel and not is_prefill:
+        if self.enable_multistream_shared_expert and not is_prefill:
             hidden_states, shared_output = hidden_states
 
         if self.dp_size > 1:
@@ -920,6 +920,6 @@ def forward(self,
         if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
             hidden_states = tensor_model_parallel_all_reduce(hidden_states)
 
-        if self.enable_cv_parallel and not is_prefill:
+        if self.enable_multistream_shared_expert and not is_prefill:
             return hidden_states, shared_output
         return hidden_states
diff --git a/vllm_ascend/quantization/w8a8_dynamic.py b/vllm_ascend/quantization/w8a8_dynamic.py
index fdc36da9d4..ecbf1a07f0 100644
--- a/vllm_ascend/quantization/w8a8_dynamic.py
+++ b/vllm_ascend/quantization/w8a8_dynamic.py
@@ -184,6 +184,24 @@ def fused_experts_with_mc2(hidden_states: torch.Tensor,
     }
     kwargs_mc2.update(stage1_kwargs)
 
+    shared_experts = kwargs.get('shared_experts', None)
+    if shared_experts:
+        shared_hidden_states = kwargs.get('shared_hidden_states', None)
+        with tng.scope.npu_stream_switch('cv'):
+            tng.scope.npu_wait_tensor(shared_hidden_states, hidden_states)
+            shared_x, shared_dynamic_scale = torch_npu.npu_dynamic_quant(
+                shared_hidden_states)
+            shared_gate_up = torch_npu.npu_quant_matmul(
+                shared_x,
+                shared_experts.gate_up_proj.weight,
+                shared_experts.gate_up_proj.weight_scale,
+                output_dtype=torch.int32,
+            )
+        kwargs.update({
+            "shared_gate_up": shared_gate_up,
+            "shared_dynamic_scale": shared_dynamic_scale,
+        })
+
     output = torch_npu.npu_moe_distribute_dispatch(**kwargs_mc2)
     # comm_stream.wait_stream(torch.npu.current_stream())
     expand_x, dynamic_scale, expand_idx, expert_token_nums, ep_recv_counts = output[

From 415394cb8aa04254d92300b2136565e7d65837d9 Mon Sep 17 00:00:00 2001
From: David9857 <985700846@qq.com>
Date: Thu, 5 Jun 2025 21:58:20 +0800
Subject: [PATCH 4/6] remove cv parallel for float model

Signed-off-by: David9857 <985700846@qq.com>
---
 vllm_ascend/models/deepseek_v2.py |  4 ++--
 vllm_ascend/ops/fused_moe.py      | 15 ---------------
 2 files changed, 2 insertions(+), 17 deletions(-)

diff --git a/vllm_ascend/models/deepseek_v2.py b/vllm_ascend/models/deepseek_v2.py
index 7e573e24c5..696449288d 100644
--- a/vllm_ascend/models/deepseek_v2.py
+++ b/vllm_ascend/models/deepseek_v2.py
@@ -286,10 +286,10 @@ def forward(
             top_k=CustomDeepseekV2MoE.top_k,
             enable_force_load_balance=enable_force_load_balance,
             **kwargs)
-
+
         if multistream:
             hidden_states, shared_output = hidden_states
-
+
         hidden_states = hidden_states * self.routed_scaling_factor
 
         if self.tp_size > 1:
diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py
index e444758b41..8b9eac2b62 100644
--- a/vllm_ascend/ops/fused_moe.py
+++ b/vllm_ascend/ops/fused_moe.py
@@ -20,7 +20,6 @@
 import torch
 import torch.distributed as dist
 import torch_npu
-import torchair as tng  # type: ignore
 from vllm.config import get_current_vllm_config
 from vllm.distributed import (GroupCoordinator,
                               get_tensor_model_parallel_world_size,
@@ -87,13 +86,6 @@ def fused_experts_with_mc2(hidden_states: torch.Tensor,
     expand_x, dynamic_scale, expand_idx, expert_token_nums, ep_recv_counts = output[
         0:5]
 
-    shared_experts = kwargs.get('shared_experts', None)
-    if shared_experts:
-        shared_gate_up = kwargs.get('shared_gate_up', None)
-        with tng.scope.npu_stream_switch('cv'):
-            tng.scope.npu_wait_tensor(shared_gate_up, expand_x)
-            shared_x = shared_experts.act_fn(shared_gate_up)
-
     w1 = w1.transpose(1, 2)
     expert_token_nums = torch.cumsum(expert_token_nums,
                                      dim=0,
@@ -122,11 +114,6 @@ def fused_experts_with_mc2(hidden_states: torch.Tensor,
         group_list=group_list,
     )
 
-    if shared_experts:
-        with tng.scope.npu_stream_switch('cv'):
-            tng.scope.npu_wait_tensor(shared_x, down_out_list)
-            shared_output = shared_experts.down_proj(shared_x)
-
     down_out_list = torch.cat(down_out_list, dim=0)
 
     # moeCombine
@@ -156,8 +143,6 @@ def fused_experts_with_mc2(hidden_states: torch.Tensor,
 
     hidden_states = torch_npu.npu_moe_distribute_combine(**kwargs_mc2)
 
-    if shared_experts:
-        return hidden_states, shared_output
     return hidden_states
 
 

From 354ff2c83249d3dd4229388331ac7383303fb43b Mon Sep 17 00:00:00 2001
From: David9857 <985700846@qq.com>
Date: Thu, 5 Jun 2025 22:09:03 +0800
Subject: [PATCH 5/6] update torchair config

Signed-off-by: David9857 <985700846@qq.com>

bugfix

Signed-off-by: David9857 <985700846@qq.com>
---
 vllm_ascend/ascend_config.py      | 4 +++-
 vllm_ascend/models/deepseek_v2.py | 1 -
 2 files changed, 3 insertions(+), 2 deletions(-)

diff --git a/vllm_ascend/ascend_config.py b/vllm_ascend/ascend_config.py
index 2463f17591..2e7d744408 100644
--- a/vllm_ascend/ascend_config.py
+++ b/vllm_ascend/ascend_config.py
@@ -53,6 +53,8 @@ def __init__(self, torchair_graph_config):
             "graph_batch_sizes", [])
         self.graph_batch_sizes_init = torchair_graph_config.get(
             "graph_batch_sizes_init", False)
+        self.enable_multistream_shared_expert = torchair_graph_config.get(
+            "enable_multistream_shared_expert", False)
 
         if not isinstance(self.graph_batch_sizes, list):
             raise TypeError("graph_batch_sizes must be list[int]")
@@ -105,7 +107,7 @@ def check_ascend_config(vllm_config, enforce_eager):
     ascend_config = get_ascend_config()
 
     # Both for V0 and V1 Engine, torchair_graph cannot be enabled with eager mode.
-    if ascend_config.torchair_graph_config.enabled and not enforce_eager:
+    if ascend_config.torchair_graph_config.enabled and enforce_eager:
         raise RuntimeError(
             "Can't enable graph mode and eager mode at the same time. Please set `enforce_eager=False` if you attempt to enable NPU graph mode."
         )
diff --git a/vllm_ascend/models/deepseek_v2.py b/vllm_ascend/models/deepseek_v2.py
index 4caa22cf57..8a1b8d29fb 100644
--- a/vllm_ascend/models/deepseek_v2.py
+++ b/vllm_ascend/models/deepseek_v2.py
@@ -30,7 +30,6 @@
 import torch
 import torch.distributed as dist
 import torch_npu
-import torchair as tng  # type: ignore
 import vllm.envs as envs
 from torch import nn
 from transformers import PretrainedConfig

From 4ae80fd13502c510d1f3b5abfa9ba1f40ca369a7 Mon Sep 17 00:00:00 2001
From: David9857 <985700846@qq.com>
Date: Thu, 5 Jun 2025 22:45:39 +0800
Subject: [PATCH 6/6] fix ut for ascend config

Signed-off-by: David9857 <985700846@qq.com>
---
 tests/singlecard/test_ascend_config.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/tests/singlecard/test_ascend_config.py b/tests/singlecard/test_ascend_config.py
index 2642c0eac0..4433538cd1 100644
--- a/tests/singlecard/test_ascend_config.py
+++ b/tests/singlecard/test_ascend_config.py
@@ -114,5 +114,6 @@ def test_ascend_config_load_error():
         },
     }
     with VllmRunner("facebook/opt-125m",
+                    enforce_eager=False,
                     additional_config=input_additional_config_fake_2):
         pass
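
Usage sketch (not part of the diffs above): a minimal offline-inference example of how the new switch could be turned on, assuming vLLM forwards `additional_config` to vllm-ascend and that `enable_multistream_shared_expert` is read from the nested `torchair_graph_config` dict as in the `ascend_config.py` hunk of PATCH 5/6. The model name, tensor-parallel size and prompt are placeholders, not values taken from this series.

from vllm import LLM, SamplingParams

# Sketch only: enable the multistream shared-expert path on top of torchair
# graph mode. Graph mode requires enforce_eager=False (see check_ascend_config).
llm = LLM(
    model="deepseek-ai/DeepSeek-V2-Lite",  # placeholder DeepSeek MoE checkpoint
    tensor_parallel_size=4,                # placeholder parallel setup
    enforce_eager=False,
    additional_config={
        "torchair_graph_config": {
            "enabled": True,
            "enable_multistream_shared_expert": True,
        },
    },
)

outputs = llm.generate(["Who are you?"],
                       SamplingParams(temperature=0.0, max_tokens=32))
print(outputs[0].outputs[0].text)

The option only takes effect on decode steps (is_prefill is False); prefill requests keep running the shared experts on the default stream.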