diff --git a/tests/singlecard/test_ascend_config.py b/tests/singlecard/test_ascend_config.py
index 2642c0eac0..4433538cd1 100644
--- a/tests/singlecard/test_ascend_config.py
+++ b/tests/singlecard/test_ascend_config.py
@@ -114,5 +114,6 @@ def test_ascend_config_load_error():
             },
         }
         with VllmRunner("facebook/opt-125m",
+                        enforce_eager=False,
                         additional_config=input_additional_config_fake_2):
             pass
diff --git a/vllm_ascend/ascend_config.py b/vllm_ascend/ascend_config.py
index 2463f17591..2e7d744408 100644
--- a/vllm_ascend/ascend_config.py
+++ b/vllm_ascend/ascend_config.py
@@ -53,6 +53,8 @@ def __init__(self, torchair_graph_config):
             "graph_batch_sizes", [])
         self.graph_batch_sizes_init = torchair_graph_config.get(
             "graph_batch_sizes_init", False)
+        self.enable_multistream_shared_expert = torchair_graph_config.get(
+            "enable_multistream_shared_expert", False)
 
         if not isinstance(self.graph_batch_sizes, list):
             raise TypeError("graph_batch_sizes must be list[int]")
@@ -105,7 +107,7 @@ def check_ascend_config(vllm_config, enforce_eager):
     ascend_config = get_ascend_config()
 
     # Both for V0 and V1 Engine, torchair_graph cannot be enabled with eager mode.
-    if ascend_config.torchair_graph_config.enabled and not enforce_eager:
+    if ascend_config.torchair_graph_config.enabled and enforce_eager:
         raise RuntimeError(
             "Can't enable graph mode and eager mode at the same time. Please set `enforce_eager=False` if you attempt to enable NPU graph mode."
         )
diff --git a/vllm_ascend/models/deepseek_v2.py b/vllm_ascend/models/deepseek_v2.py
index 0b412782f0..8a1b8d29fb 100644
--- a/vllm_ascend/models/deepseek_v2.py
+++ b/vllm_ascend/models/deepseek_v2.py
@@ -216,6 +216,8 @@ def __init__(
 
         ascend_config = get_ascend_config()
         self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
+        self.enable_multistream_shared_expert = \
+            ascend_config.torchair_graph_config.enable_multistream_shared_expert
 
     def forward(
         self,
@@ -238,6 +240,8 @@ def forward(
 
         num_tokens, hidden_size = hidden_states.shape
 
+        multistream = self.enable_multistream_shared_expert and not is_prefill
+
         old_hidden_states = hidden_states.clone()
 
         if self.tp_size > 1:
@@ -259,13 +263,25 @@ def forward(
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
 
+        kwargs = {}
+        if multistream:
+            kwargs.update({
+                "shared_experts": self.shared_experts,
+                "shared_hidden_states": old_hidden_states
+            })
+
         hidden_states = self.experts(
             hidden_states=hidden_states,
             router_logits=router_logits,
             is_prefill=is_prefill,
             top_k=CustomDeepseekV2MoE.top_k,
             enable_force_load_balance=enable_force_load_balance,
-        ) * self.routed_scaling_factor
+            **kwargs)
+
+        if multistream:
+            hidden_states, shared_output = hidden_states
+
+        hidden_states = hidden_states * self.routed_scaling_factor
 
         if self.tp_size > 1:
             if self.torchair_graph_enabled:
@@ -288,7 +304,8 @@ def forward(
             hidden_states = hidden_states[:-num_padding_tokens]
 
         if self.n_shared_experts is not None:
-            shared_output = self.shared_experts(old_hidden_states)
+            if not multistream:
+                shared_output = self.shared_experts(old_hidden_states)
 
         if shared_output is not None:
             hidden_states = hidden_states + shared_output
diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py
index 05eedfc379..6aff62fc62 100644
--- a/vllm_ascend/ops/fused_moe.py
+++ b/vllm_ascend/ops/fused_moe.py
@@ -39,19 +39,18 @@
 USING_LCCL_COM: bool = envs_ascend.USING_LCCL_COM
 
 
-def fused_experts_with_mc2(
-    hidden_states: torch.Tensor,
-    w1: torch.Tensor,
-    w2: torch.Tensor,
-    topk_weights: torch.Tensor,
-    topk_ids: torch.Tensor,
-    top_k: int,
-    expert_map: torch.Tensor = None,
-    moe_all_to_all_group_name: Optional[str] = None,
-) -> torch.Tensor:
+def fused_experts_with_mc2(hidden_states: torch.Tensor,
+                           w1: torch.Tensor,
+                           w2: torch.Tensor,
+                           topk_weights: torch.Tensor,
+                           topk_ids: torch.Tensor,
+                           top_k: int,
+                           expert_map: torch.Tensor = None,
+                           moe_all_to_all_group_name: Optional[str] = None,
+                           **kwargs) -> torch.Tensor:
     global_bs = 0
     moe_expert_num = len(expert_map)
-    kwargs = {
+    kwargs_mc2 = {
         "x": hidden_states,
         "expert_ids": topk_ids,
         "expert_shard_type": 0,
@@ -81,9 +80,9 @@ def fused_experts_with_mc2(
         "tp_world_size": tp_size,
         "tp_rank_id": tp_rank,
     }
-    kwargs.update(stage1_kwargs)
+    kwargs_mc2.update(stage1_kwargs)
 
-    output = torch_npu.npu_moe_distribute_dispatch(**kwargs)
+    output = torch_npu.npu_moe_distribute_dispatch(**kwargs_mc2)
     # comm_stream.wait_stream(torch.npu.current_stream())
     expand_x, dynamic_scale, expand_idx, expert_token_nums, ep_recv_counts = output[
         0:5]
@@ -119,7 +118,7 @@ def fused_experts_with_mc2(
     down_out_list = torch.cat(down_out_list, dim=0)
 
     # moeCombine
-    kwargs = {
+    kwargs_mc2 = {
         "expand_x": down_out_list,
         "expert_ids": topk_ids,
         "expand_idx": expand_idx,
@@ -141,9 +140,9 @@ def fused_experts_with_mc2(
         "tp_world_size": tp_size,
         "tp_rank_id": tp_rank,
     }
-    kwargs.update(stage3_kwargs)
+    kwargs_mc2.update(stage3_kwargs)
 
-    hidden_states = torch_npu.npu_moe_distribute_combine(**kwargs)
+    hidden_states = torch_npu.npu_moe_distribute_combine(**kwargs_mc2)
 
     return hidden_states
 
@@ -675,7 +674,8 @@ def apply(
                 topk_ids=topk_ids,
                 top_k=top_k,
                 expert_map=expert_map,
-                moe_all_to_all_group_name=self.moe_all_to_all_group_name)
+                moe_all_to_all_group_name=self.moe_all_to_all_group_name,
+                **kwargs)
         elif self.torchair_graph_enabled or get_ep_group().world_size == 1:
             return fused_experts(hidden_states=x,
                                  w1=layer.w13_weight,
@@ -772,6 +772,8 @@ def __init__(
 
         ascend_config = get_ascend_config()
         self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
+        self.enable_multistream_shared_expert = \
+            ascend_config.torchair_graph_config.enable_multistream_shared_expert
 
         if self.scoring_func != "softmax" and not self.use_grouped_topk:
             raise ValueError("Only softmax scoring function is supported for "
@@ -818,7 +820,8 @@ def forward(self,
                 router_logits: torch.Tensor,
                 is_prefill: bool,
                 enable_force_load_balance: bool = False,
-                top_k=None):
+                top_k=None,
+                **kwargs):
         assert self.quant_method is not None
 
         if top_k:
@@ -862,7 +865,11 @@ def forward(self,
             scoring_func=self.scoring_func,
             e_score_correction_bias=self.e_score_correction_bias,
             is_prefill=is_prefill,
-            enable_force_load_balance=enable_force_load_balance)
+            enable_force_load_balance=enable_force_load_balance,
+            **kwargs)
+
+        if self.enable_multistream_shared_expert and not is_prefill:
+            hidden_states, shared_output = hidden_states
 
         if self.dp_size > 1:
             if VLLM_ENABLE_MC2 and not is_prefill:
@@ -886,4 +893,6 @@ def forward(self,
         if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
             hidden_states = tensor_model_parallel_all_reduce(hidden_states)
 
+        if self.enable_multistream_shared_expert and not is_prefill:
+            return hidden_states, shared_output
         return hidden_states
diff --git a/vllm_ascend/quantization/quant_config.py b/vllm_ascend/quantization/quant_config.py
index 40dbae38f8..e43f25d5bb 100644
--- a/vllm_ascend/quantization/quant_config.py
+++ b/vllm_ascend/quantization/quant_config.py
@@ -329,7 +329,7 @@ def apply(
             layer, x, router_logits, top_k, renormalize, use_grouped_topk,
             global_num_experts, expert_map, topk_group, num_expert_group,
             custom_routing_function, scoring_func, e_score_correction_bias,
-            is_prefill, enable_force_load_balance)
+            is_prefill, enable_force_load_balance, **kwargs)
 
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         if hasattr(self.quant_method, "process_weights_after_loading"):
diff --git a/vllm_ascend/quantization/w8a8_dynamic.py b/vllm_ascend/quantization/w8a8_dynamic.py
index 9d651fbfc8..68d70bc788 100644
--- a/vllm_ascend/quantization/w8a8_dynamic.py
+++ b/vllm_ascend/quantization/w8a8_dynamic.py
@@ -20,7 +20,8 @@
 import torch
 import torch.distributed as dist
 import torch_npu
-from vllm.distributed import GroupCoordinator
+import torchair as tng  # type: ignore
+from vllm.distributed import GroupCoordinator, tensor_model_parallel_all_reduce
 
 import vllm_ascend.envs as envs_ascend
 from vllm_ascend.ascend_config import get_ascend_config
@@ -38,7 +39,8 @@ def apply_mlp(hidden_states: torch.Tensor,
               w2_scale: torch.Tensor,
               group_list: torch.Tensor,
               dynamic_scale: torch.Tensor = None,
-              group_list_type: int = 1) -> torch.Tensor:
+              group_list_type: int = 1,
+              **kwargs) -> torch.Tensor:
     """
     apply MLP: gate_up_proj -> swiglu -> down_proj
 
@@ -72,6 +74,23 @@ def apply_mlp(hidden_states: torch.Tensor,
     else:
         pertoken_scale = dynamic_scale
 
+    shared_experts = kwargs.get('shared_experts', None)
+    if shared_experts:
+        shared_gate_up = kwargs.get('shared_gate_up', None)
+        shared_dynamic_scale = kwargs.get('shared_dynamic_scale', None)
+        with tng.scope.npu_stream_switch('cv'):
+            tng.scope.npu_wait_tensor(shared_gate_up, hidden_states)
+            shared_x, shared_dynamic_scale = torch_npu.npu_dequant_swiglu_quant(
+                x=shared_gate_up,
+                weight_scale=shared_experts.gate_up_proj.weight_scale_fp32,
+                activation_scale=shared_dynamic_scale,
+                bias=None,
+                quant_scale=None,
+                quant_offset=None,
+                group_index=None,
+                activate_left=True,
+                quant_mode=1)
+
     # gmm1: gate_up_proj
     hidden_states = torch_npu.npu_grouped_matmul(
         x=[hidden_states],
@@ -100,25 +119,39 @@ def apply_mlp(hidden_states: torch.Tensor,
         group_type=0,
         group_list=group_list,
         output_dtype=w2_scale.dtype)[0]
+
+    if shared_experts:
+        with tng.scope.npu_stream_switch('cv'):
+            tng.scope.npu_wait_tensor(shared_x, hidden_states)
+            shared_output = torch_npu.npu_quant_matmul(
+                shared_x,
+                shared_experts.down_proj.weight,
+                shared_experts.down_proj.weight_scale,
+                pertoken_scale=shared_dynamic_scale,
+                output_dtype=torch.bfloat16,
+            )
+            if shared_experts.down_proj.reduce_results and shared_experts.down_proj.tp_size > 1:
+                shared_output = tensor_model_parallel_all_reduce(shared_output)
+    if shared_experts:
+        return hidden_states, shared_output
     return hidden_states
 
 
-def fused_experts_with_mc2(
-    hidden_states: torch.Tensor,
-    w1: torch.Tensor,
-    w2: torch.Tensor,
-    w1_scale: torch.Tensor,
-    w2_scale: torch.Tensor,
-    topk_weights: torch.Tensor,
-    topk_ids: torch.Tensor,
-    top_k: int,
-    expert_map: torch.Tensor = None,
-    moe_all_to_all_group_name: str = "",
-) -> torch.Tensor:
+def fused_experts_with_mc2(hidden_states: torch.Tensor,
+                           w1: torch.Tensor,
+                           w2: torch.Tensor,
+                           w1_scale: torch.Tensor,
+                           w2_scale: torch.Tensor,
+                           topk_weights: torch.Tensor,
+                           topk_ids: torch.Tensor,
+                           top_k: int,
+                           expert_map: torch.Tensor = None,
+                           moe_all_to_all_group_name: str = "",
+                           **kwargs) -> torch.Tensor:
     global_bs = 0
     moe_expert_num = len(expert_map)
     # hidden_states = hidden_states.bfloat16()
-    kwargs = {
+    kwargs_mc2 = {
         "x": hidden_states,
         "expert_ids": topk_ids,
         "expert_shard_type": 0,
@@ -149,9 +182,27 @@ def fused_experts_with_mc2(
         "tp_world_size": tp_size,
         "tp_rank_id": tp_rank,
     }
-    kwargs.update(stage1_kwargs)
+    kwargs_mc2.update(stage1_kwargs)
+
+    shared_experts = kwargs.get('shared_experts', None)
+    if shared_experts:
+        shared_hidden_states = kwargs.get('shared_hidden_states', None)
+        with tng.scope.npu_stream_switch('cv'):
+            tng.scope.npu_wait_tensor(shared_hidden_states, hidden_states)
+            shared_x, shared_dynamic_scale = torch_npu.npu_dynamic_quant(
+                shared_hidden_states)
+            shared_gate_up = torch_npu.npu_quant_matmul(
+                shared_x,
+                shared_experts.gate_up_proj.weight,
+                shared_experts.gate_up_proj.weight_scale,
+                output_dtype=torch.int32,
+            )
+        kwargs.update({
+            "shared_gate_up": shared_gate_up,
+            "shared_dynamic_scale": shared_dynamic_scale,
+        })
 
-    output = torch_npu.npu_moe_distribute_dispatch(**kwargs)
+    output = torch_npu.npu_moe_distribute_dispatch(**kwargs_mc2)
     # comm_stream.wait_stream(torch.npu.current_stream())
     expand_x, dynamic_scale, expand_idx, expert_token_nums, ep_recv_counts = output[
         0:5]
@@ -166,10 +217,15 @@ def fused_experts_with_mc2(
                              w2,
                              w2_scale,
                              expert_token_nums,
-                             dynamic_scale=dynamic_scale)
+                             dynamic_scale=dynamic_scale,
+                             **kwargs)
+
+    multi_stream = isinstance(down_out_list, tuple)
+    if multi_stream:
+        down_out_list, shared_output = down_out_list
 
     # moeCombine
-    kwargs = {
+    kwargs_mc2 = {
         "expand_x": down_out_list,
         "expert_ids": topk_ids,
         "expand_idx": expand_idx,
@@ -193,10 +249,12 @@ def fused_experts_with_mc2(
         "tp_world_size": tp_size,
         "tp_rank_id": tp_rank,
     }
-    kwargs.update(stage3_kwargs)
+    kwargs_mc2.update(stage3_kwargs)
 
-    hidden_states = torch_npu.npu_moe_distribute_combine(**kwargs)
+    hidden_states = torch_npu.npu_moe_distribute_combine(**kwargs_mc2)
 
+    if multi_stream:
+        return hidden_states, shared_output
     return hidden_states
 
 
@@ -634,7 +692,8 @@ def apply(
             topk_ids=topk_ids,
             top_k=top_k,
             expert_map=expert_map,
-            moe_all_to_all_group_name=self.moe_all_to_all_group_name)
+            moe_all_to_all_group_name=self.moe_all_to_all_group_name,
+            **kwargs)
         elif self.torchair_graph_enabled or self.ep_group.world_size == 1:
             return fused_experts(hidden_states=x,
                                  w1=layer.w13_weight,
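
---

For reviewers, a minimal usage sketch of the new knob introduced by this patch. It mirrors the `additional_config` shape exercised in the updated test above; the offline `LLM` entry point, the model name, and the sampling call are illustrative assumptions, not part of the change. The flag defaults to `False` and, as the diff shows, only takes effect on the decode path (`not is_prefill`) with torchair graph mode on (`enforce_eager=False`) and the w8a8-dynamic quantized shared expert.

```python
from vllm import LLM, SamplingParams

# Illustrative sketch only: turn on torchair graph mode together with the
# multistream shared-expert path added by this patch. The flag is parsed in
# vllm_ascend/ascend_config.py (TorchairGraphConfig) and defaults to False.
llm = LLM(
    model="deepseek-ai/DeepSeek-V2-Lite",  # assumed DeepSeek MoE checkpoint
    enforce_eager=False,  # graph mode must stay on, per check_ascend_config()
    additional_config={
        "torchair_graph_config": {
            "enabled": True,
            "enable_multistream_shared_expert": True,
        },
    },
)

# Ordinary generation; the shared-expert compute is overlapped on a side
# stream ('cv') during decode inside fused_experts_with_mc2/apply_mlp.
outputs = llm.generate(["Hello"], SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)
```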