diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py
index 5e1be9fad9..ae305ecc10 100644
--- a/vllm_ascend/attention/mla_v1.py
+++ b/vllm_ascend/attention/mla_v1.py
@@ -11,6 +11,7 @@
 from vllm.config import get_current_vllm_config
 from vllm.model_executor.layers.linear import (LinearBase,
                                                UnquantizedLinearMethod)
+from vllm.platforms import current_platform
 from vllm.utils import cdiv, round_down

 from vllm_ascend import envs
@@ -98,6 +99,7 @@ class AscendMLADecodeMetadata:
     attn_mask: Optional[torch.Tensor] = None
     sin: torch.Tensor = None
     cos: torch.Tensor = None
+    mc2_mask: Optional[torch.Tensor] = None


 @dataclass
@@ -213,6 +215,13 @@ def __init__(self,
         self.cos_cache = None
         self.sin_cache = None

+    def generate_activate_mask(self, actual_seqs_num, batch_size):
+        mc2_mask = torch.zeros(batch_size,
+                               dtype=torch.bool,
+                               device=current_platform.device_type)
+        mc2_mask[:actual_seqs_num].fill_(True)
+        return mc2_mask
+
     def reorder_batch(self, input_batch: "InputBatch",
                       scheduler_output: "SchedulerOutput") -> bool:
         # We now want to reorder the batch so that the "decode" requests are at
@@ -355,6 +364,7 @@ def build_torchair_graph_dummy(
                          self.rope_dim,
                          dtype=self.runner.dtype,
                          device=device)
+        mc2_mask = self.generate_activate_mask(num_actual_tokens, num_reqs)
         decode_metadata = AscendMLADecodeMetadata(
             input_positions=input_positions,
             block_table=block_table,
@@ -364,7 +374,8 @@ def build_torchair_graph_dummy(
             attn_mask=self.runner.spec_attn_mask,
             actual_seq_q_lens=self.runner.actual_seq_q_lens[:num_reqs],
             sin=sin,
-            cos=cos)
+            cos=cos,
+            mc2_mask=mc2_mask)
         return self.metadata_cls(  # type: ignore
             num_input_tokens=num_actual_tokens,
             num_actual_tokens=num_actual_tokens,
@@ -545,6 +556,8 @@ def build(
             else:
                 seq_lens_list = seq_lens.tolist()
                 cos, sin = None, None
+            mc2_mask = self.generate_activate_mask(
+                num_actual_tokens, num_reqs + num_reqs_pad_size)

             decode_metadata = AscendMLADecodeMetadata(
                 input_positions=input_positions,
@@ -555,7 +568,8 @@ def build(
                 attn_mask=self.runner.spec_attn_mask,
                 actual_seq_q_lens=actual_seq_q_lens,
                 sin=sin,
-                cos=cos)
+                cos=cos,
+                mc2_mask=mc2_mask)

         return self.metadata_cls(  # type: ignore
             num_actual_tokens=num_actual_tokens,
diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py
index fe1164fd4d..a85877a56f 100644
--- a/vllm_ascend/ops/fused_moe.py
+++ b/vllm_ascend/ops/fused_moe.py
@@ -122,6 +122,7 @@ def fused_experts_with_mc2(
     moe_all_to_all_group_name: Optional[str] = None,
     shared_experts: Optional[Any] = None,
     is_torchair: bool = False,
+    mc2_mask: Optional[torch.Tensor] = None,
 ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
     quant_mode = 0
     ep_group = get_ep_group()
@@ -138,6 +139,9 @@ def fused_experts_with_mc2(
     need_extra_args = (get_ascend_soc_version() == AscendSocVersion.A3
                        or is_torchair)

+    # NOTE: Currently, when in A3, we need to pass in some extra param into dispatch & combine
+    a3_need_extra_args = get_ascend_soc_version() == AscendSocVersion.A3
+
     moe_expert_num = len(expert_map)
     kwargs_mc2 = {
         "x": hidden_states,
@@ -161,6 +165,10 @@ def fused_experts_with_mc2(
             "tp_world_size": 1,
             "tp_rank_id": 0,
         })
+    if a3_need_extra_args:
+        stage1_kwargs.update({
+            "x_active_mask": mc2_mask,
+        })

     kwargs_mc2.update(stage1_kwargs)

@@ -230,6 +238,10 @@ def fused_experts_with_mc2(
             "tp_world_size": 1,
             "tp_rank_id": 0,
         })
+    if a3_need_extra_args:
+        stage3_kwargs.update({
+            "x_active_mask": mc2_mask,
+        })
     kwargs_mc2.update(stage3_kwargs)

     hidden_states = torch_npu.npu_moe_distribute_combine(**kwargs_mc2)
@@ -944,6 +956,7 @@ def apply(

         fused_moe_state = get_forward_context().fused_moe_state
         if fused_moe_state == FusedMoEState.MC2:
+            mc2_mask = kwargs.get("mc2_mask", None)
             return fused_experts_with_mc2(
                 hidden_states=x,
                 w1=layer.w13_weight,
@@ -954,7 +967,8 @@ def apply(
                 expert_map=expert_map,
                 moe_all_to_all_group_name=self.moe_all_to_all_group_name,
                 shared_experts=shared_experts,
-                is_torchair=self.torchair_graph_enabled)
+                is_torchair=self.torchair_graph_enabled,
+                mc2_mask=mc2_mask)
         elif fused_moe_state == FusedMoEState.AllGather:
             return fused_experts(hidden_states=x,
                                  w1=layer.w13_weight,
@@ -1155,6 +1169,9 @@ def forward(self,
             if not self.enable_multistream_moe or fused_moe_state != FusedMoEState.MC2:
                 shared_hidden_states = shared_experts(hidden_states)

+        attn_metadata = get_forward_context().attn_metadata
+        mc2_mask = attn_metadata.decode.mc2_mask if attn_metadata is not None and attn_metadata.decode is not None else None
+
         tp_size = get_tensor_model_parallel_world_size()
         if tp_size > 1 and fused_moe_state != FusedMoEState.AllGather:
             if num_tokens < tp_size:
@@ -1162,6 +1179,9 @@ def forward(self,
                     hidden_states, (0, 0, 0, tp_size - num_tokens))
                 router_logits = nn.functional.pad(
                     router_logits, (0, 0, 0, tp_size - num_tokens))
+                if mc2_mask is not None:
+                    mc2_mask = nn.functional.pad(mc2_mask,
+                                                 (0, tp_size - num_tokens))
             chunk_hidden_states = torch.tensor_split(hidden_states,
                                                      tp_size,
                                                      dim=0)
@@ -1171,6 +1191,11 @@ def forward(self,
             tp_rank = get_tensor_model_parallel_rank()
             hidden_states = chunk_hidden_states[tp_rank]
             router_logits = chunk_router_logits[tp_rank]
+
+            if mc2_mask is not None:
+                chunk_mc2_mask = torch.tensor_split(mc2_mask, tp_size, dim=0)
+                mc2_mask = chunk_mc2_mask[tp_rank]
+
         if self.dp_size > 1 and fused_moe_state == FusedMoEState.AllGather:
             # NOTE: When in torchair graph, it has been padded in model_runner_v1
             if not self.torchair_graph_enabled or is_prefill:
@@ -1209,6 +1234,7 @@ def forward(self,
             and self.enable_multistream_moe and not is_prefill else None,
             quantized_x_for_share=quantized_x_for_share,
             dynamic_scale_for_share=dynamic_scale_for_share,
+            mc2_mask=mc2_mask,
         )

         if shared_experts:
diff --git a/vllm_ascend/quantization/w8a8_dynamic.py b/vllm_ascend/quantization/w8a8_dynamic.py
index a9938c14f2..3561675095 100644
--- a/vllm_ascend/quantization/w8a8_dynamic.py
+++ b/vllm_ascend/quantization/w8a8_dynamic.py
@@ -215,6 +215,7 @@ def fused_experts_with_mc2(
     w2_scale_bias: torch.Tensor = None,
     quantized_x_for_share: Optional[Any] = None,
     dynamic_scale_for_share: Optional[Any] = None,
+    mc2_mask: Optional[torch.Tensor] = None,
 ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
     if log2phy:
         topk_ids = log2phy[topk_ids]
@@ -233,6 +234,9 @@ def fused_experts_with_mc2(
     need_extra_args = (get_ascend_soc_version() == AscendSocVersion.A3
                        or is_torchair)

+    # NOTE: Currently, when in A3, we need to pass in some extra param into dispatch & combine
+    a3_need_extra_args = get_ascend_soc_version() == AscendSocVersion.A3
+
     if (expert_map is not None):
         moe_expert_num = len(expert_map) + global_redundant_expert_num
     else:
@@ -260,6 +264,10 @@ def fused_experts_with_mc2(
             "tp_world_size": 1,
             "tp_rank_id": 0,
         })
+    if a3_need_extra_args:
+        stage1_kwargs.update({
+            "x_active_mask": mc2_mask,
+        })
     kwargs_mc2.update(stage1_kwargs)

     output = torch_npu.npu_moe_distribute_dispatch(**kwargs_mc2)
@@ -310,6 +318,10 @@ def fused_experts_with_mc2(
             "tp_world_size": 1,
             "tp_rank_id": 0,
         })
+    if a3_need_extra_args:
+        stage3_kwargs.update({
+            "x_active_mask": mc2_mask,
+        })
     kwargs_mc2.update(stage3_kwargs)

     hidden_states = torch_npu.npu_moe_distribute_combine(**kwargs_mc2)
@@ -791,6 +803,7 @@ def apply(
         topk_weights = topk_weights.to(x.dtype)

         if fused_moe_state == FusedMoEState.MC2:
+            mc2_mask = kwargs.get("mc2_mask", None)
             return fused_experts_with_mc2(
                 hidden_states=x,
                 w1=layer.w13_weight,
@@ -807,7 +820,8 @@ def apply(
                 expert_map=expert_map,
                 moe_all_to_all_group_name=self.moe_all_to_all_group_name,
                 shared_experts=shared_experts,
                 is_torchair=self.torchair_graph_enabled,
                 quantized_x_for_share=shared_gate_up,
-                dynamic_scale_for_share=shared_dequant_scale)
+                dynamic_scale_for_share=shared_dequant_scale,
+                mc2_mask=mc2_mask)
         elif fused_moe_state == FusedMoEState.AllGather:
             return fused_experts(hidden_states=x,
                                  w1=layer.w13_weight,
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index 11ee1ab526..12f502a4d7 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -1668,6 +1668,8 @@ def _dummy_run(
                         attn_metadata.decode.input_positions)
                     torch._dynamo.mark_static(attn_metadata.decode.sin)
                     torch._dynamo.mark_static(attn_metadata.decode.cos)
+                    torch._dynamo.mark_static(
+                        attn_metadata.decode.mc2_mask)
                     torch._dynamo.mark_static(attn_metadata.slot_mapping)
                     for kv in self.kv_caches:
                         assert isinstance(
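
For reference, the short sketch below walks through the mask plumbing this patch introduces: build a boolean activation mask for a padded decode batch, pad and shard it alongside the hidden states under tensor parallelism, and attach it as "x_active_mask" only when the extra argument is needed. It is a minimal CPU-only illustration, not the vllm-ascend implementation: the helpers shard_mask_for_tp and build_dispatch_kwargs and the is_a3 flag are invented for this example, whereas the real code gates on get_ascend_soc_version() == AscendSocVersion.A3 and forwards the kwargs to torch_npu.npu_moe_distribute_dispatch / npu_moe_distribute_combine.

import torch
import torch.nn.functional as F


def generate_activate_mask(actual_seqs_num: int, batch_size: int) -> torch.Tensor:
    # True for real requests, False for padded slots (CPU stand-in for the
    # generate_activate_mask helper added to mla_v1.py).
    mc2_mask = torch.zeros(batch_size, dtype=torch.bool)
    mc2_mask[:actual_seqs_num].fill_(True)
    return mc2_mask


def shard_mask_for_tp(mc2_mask: torch.Tensor, num_tokens: int, tp_size: int,
                      tp_rank: int) -> torch.Tensor:
    # Pad the mask together with the hidden states when the batch is smaller than
    # the TP world size, then keep only this rank's chunk (mirrors the pad /
    # torch.tensor_split logic added to forward() in fused_moe.py).
    if num_tokens < tp_size:
        mc2_mask = F.pad(mc2_mask, (0, tp_size - num_tokens))
    return torch.tensor_split(mc2_mask, tp_size, dim=0)[tp_rank]


def build_dispatch_kwargs(hidden_states: torch.Tensor, mc2_mask: torch.Tensor,
                          is_a3: bool) -> dict:
    # Hypothetical helper: attach the activation mask only when the SoC needs the
    # extra "x_active_mask" argument, as the a3_need_extra_args branches do.
    kwargs = {"x": hidden_states}
    if is_a3:
        kwargs["x_active_mask"] = mc2_mask
    return kwargs


if __name__ == "__main__":
    mask = generate_activate_mask(actual_seqs_num=3, batch_size=8)
    local_mask = shard_mask_for_tp(mask, num_tokens=8, tp_size=2, tp_rank=0)
    kwargs = build_dispatch_kwargs(torch.randn(4, 16), local_mask, is_a3=True)
    print(mask.tolist(), local_mask.tolist(), sorted(kwargs))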