From 88c31da00e02f5ad5486184385deb8cce27c6c6e Mon Sep 17 00:00:00 2001 From: weiguihua2 Date: Mon, 7 Jul 2025 11:00:22 +0800 Subject: [PATCH 1/4] add mc2 mask Signed-off-by: weiguihua2 --- vllm_ascend/attention/mla_v1.py | 11 +++++++++++ vllm_ascend/ops/fused_moe.py | 23 ++++++++++++++++++++++- vllm_ascend/quantization/w8a8_dynamic.py | 12 ++++++++++++ vllm_ascend/worker/model_runner_v1.py | 2 ++ 4 files changed, 47 insertions(+), 1 deletion(-) diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py index 816d93c028..11b079c469 100644 --- a/vllm_ascend/attention/mla_v1.py +++ b/vllm_ascend/attention/mla_v1.py @@ -11,6 +11,7 @@ from vllm.config import get_current_vllm_config from vllm.model_executor.layers.linear import (LinearBase, UnquantizedLinearMethod) +from vllm.platforms import current_platform from vllm.utils import cdiv, round_down from vllm_ascend import envs @@ -94,6 +95,7 @@ class AscendMLADecodeMetadata: seq_lens_list: list[int] actual_seq_q_lens: Optional[list[int]] = None attn_mask: Optional[torch.Tensor] = None + mc2_mask: Optional[torch.Tensor] = None @dataclass @@ -206,6 +208,11 @@ def __init__(self, ascend_config = get_ascend_config() self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled + def generate_active_mask(self, actual_seqs_num, batch_size): + mc2_mask = torch.zeros(batch_size, dtype=torch.bool, device=current_platform.device_type) + mc2_mask[:actual_seqs_num].fill_(True) + return mc2_mask + def reorder_batch(self, input_batch: "InputBatch", scheduler_output: "SchedulerOutput") -> bool: # We now want to reorder the batch so that the "decode" requests are at @@ -336,6 +343,7 @@ def build_torchair_graph_dummy( else: attn_state = AscendAttentionState.DecodeOnly num_decode_tokens = 1 + mc2_mask = self.generate_active_mask(num_actual_tokens, num_reqs) decode_metadata = AscendMLADecodeMetadata( input_positions=input_positions, block_table=block_table, @@ -344,6 +352,7 @@ def build_torchair_graph_dummy( max_seq_lens=1, attn_mask=self.runner.spec_attn_mask, actual_seq_q_lens=self.runner.actual_seq_q_lens[:num_reqs], + mc2_mask=mc2_mask, ) return self.metadata_cls( # type: ignore num_input_tokens=num_actual_tokens, @@ -500,6 +509,7 @@ def build( num_reqs_pad_size] else: seq_lens_list = seq_lens.tolist() + mc2_mask = self.generate_active_mask(num_actual_tokens, num_reqs) decode_metadata = AscendMLADecodeMetadata( input_positions=input_positions, @@ -509,6 +519,7 @@ def build( max_seq_lens=max_seq_lens, attn_mask=self.runner.spec_attn_mask, actual_seq_q_lens=actual_seq_q_lens, + mc2_mask=mc2_mask, ) return self.metadata_cls( # type: ignore diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py index fe1164fd4d..a3ad96519b 100644 --- a/vllm_ascend/ops/fused_moe.py +++ b/vllm_ascend/ops/fused_moe.py @@ -122,6 +122,7 @@ def fused_experts_with_mc2( moe_all_to_all_group_name: Optional[str] = None, shared_experts: Optional[Any] = None, is_torchair: bool = False, + mc2_mask: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: quant_mode = 0 ep_group = get_ep_group() @@ -137,6 +138,9 @@ def fused_experts_with_mc2( # NOTE: Currently, when in A3 or in torchair graph, we need to pass in some extra param into dispatch & combine need_extra_args = (get_ascend_soc_version() == AscendSocVersion.A3 or is_torchair) + + # NOTE: Currently, when in A3, we need to pass in some extra param into dispatch & combine + a3_need_extra_args = get_ascend_soc_version() == AscendSocVersion.A3 moe_expert_num 
= len(expert_map) kwargs_mc2 = { @@ -161,6 +165,10 @@ def fused_experts_with_mc2( "tp_world_size": 1, "tp_rank_id": 0, }) + if a3_need_extra_args: + stage1_kwargs.update({ + "x_active_mask": mc2_mask, + }) kwargs_mc2.update(stage1_kwargs) @@ -230,6 +238,10 @@ def fused_experts_with_mc2( "tp_world_size": 1, "tp_rank_id": 0, }) + if a3_need_extra_args: + stage3_kwargs.update({ + "x_active_mask": mc2_mask, + }) kwargs_mc2.update(stage3_kwargs) hidden_states = torch_npu.npu_moe_distribute_combine(**kwargs_mc2) @@ -944,6 +956,7 @@ def apply( fused_moe_state = get_forward_context().fused_moe_state if fused_moe_state == FusedMoEState.MC2: + mc2_mask = kwargs.get("mc2_mask", None) return fused_experts_with_mc2( hidden_states=x, w1=layer.w13_weight, @@ -954,7 +967,8 @@ def apply( expert_map=expert_map, moe_all_to_all_group_name=self.moe_all_to_all_group_name, shared_experts=shared_experts, - is_torchair=self.torchair_graph_enabled) + is_torchair=self.torchair_graph_enabled, + mc2_mask=mc2_mask) elif fused_moe_state == FusedMoEState.AllGather: return fused_experts(hidden_states=x, w1=layer.w13_weight, @@ -1154,6 +1168,9 @@ def forward(self, if shared_experts: if not self.enable_multistream_moe or fused_moe_state != FusedMoEState.MC2: shared_hidden_states = shared_experts(hidden_states) + + attn_metadata = get_forward_context().attn_metadata + mc2_mask = attn_metadata.decode.mc2_mask if attn_metadata is not None and attn_metadata.decode is not None else None tp_size = get_tensor_model_parallel_world_size() if tp_size > 1 and fused_moe_state != FusedMoEState.AllGather: @@ -1171,6 +1188,9 @@ def forward(self, tp_rank = get_tensor_model_parallel_rank() hidden_states = chunk_hidden_states[tp_rank] router_logits = chunk_router_logits[tp_rank] + if mc2_mask is not None: + chunk_mc2_mask = torch.tensor_split(mc2_mask, tp_size, dim=0) + mc2_mask = chunk_mc2_mask[tp_rank] if self.dp_size > 1 and fused_moe_state == FusedMoEState.AllGather: # NOTE: When in torchair graph, it has been padded in model_runner_v1 if not self.torchair_graph_enabled or is_prefill: @@ -1209,6 +1229,7 @@ def forward(self, and self.enable_multistream_moe and not is_prefill else None, quantized_x_for_share=quantized_x_for_share, dynamic_scale_for_share=dynamic_scale_for_share, + mc2_mask=mc2_mask, ) if shared_experts: diff --git a/vllm_ascend/quantization/w8a8_dynamic.py b/vllm_ascend/quantization/w8a8_dynamic.py index a9938c14f2..b825183ea4 100644 --- a/vllm_ascend/quantization/w8a8_dynamic.py +++ b/vllm_ascend/quantization/w8a8_dynamic.py @@ -215,6 +215,7 @@ def fused_experts_with_mc2( w2_scale_bias: torch.Tensor = None, quantized_x_for_share: Optional[Any] = None, dynamic_scale_for_share: Optional[Any] = None, + mc2_mask: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: if log2phy: topk_ids = log2phy[topk_ids] @@ -232,6 +233,9 @@ def fused_experts_with_mc2( # NOTE: Currently, when in A3 or in torchair graph, we need to pass in some extra param into dispatch & combine need_extra_args = (get_ascend_soc_version() == AscendSocVersion.A3 or is_torchair) + + # NOTE: Currently, when in A3, we need to pass in some extra param into dispatch & combine + a3_need_extra_args = get_ascend_soc_version() == AscendSocVersion.A3 if (expert_map is not None): moe_expert_num = len(expert_map) + global_redundant_expert_num @@ -260,6 +264,10 @@ def fused_experts_with_mc2( "tp_world_size": 1, "tp_rank_id": 0, }) + if a3_need_extra_args: + stage1_kwargs.update({ + "x_active_mask": mc2_mask, + }) 
kwargs_mc2.update(stage1_kwargs) output = torch_npu.npu_moe_distribute_dispatch(**kwargs_mc2) @@ -310,6 +318,10 @@ def fused_experts_with_mc2( "tp_world_size": 1, "tp_rank_id": 0, }) + if a3_need_extra_args: + stage3_kwargs.update({ + "x_active_mask": mc2_mask, + }) kwargs_mc2.update(stage3_kwargs) hidden_states = torch_npu.npu_moe_distribute_combine(**kwargs_mc2) diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index e3b81a076b..a59bb3a9b4 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -1666,6 +1666,8 @@ def _dummy_run( attn_metadata.decode.block_table) torch._dynamo.mark_static( attn_metadata.decode.input_positions) + torch._dynamo.mark_static( + attn_metadata.decode.mc2_mask) torch._dynamo.mark_static(attn_metadata.slot_mapping) for kv in self.kv_caches: assert isinstance( From 8de805a4ac559bfa7ee4b8102cd205dd46ca9b65 Mon Sep 17 00:00:00 2001 From: weiguihua2 Date: Mon, 7 Jul 2025 14:38:31 +0800 Subject: [PATCH 2/4] add mc2 mask Signed-off-by: weiguihua2 --- vllm_ascend/attention/mla_v1.py | 93 ++++++++++++++++-------- vllm_ascend/ops/fused_moe.py | 9 ++- vllm_ascend/quantization/w8a8_dynamic.py | 6 +- vllm_ascend/worker/model_runner_v1.py | 2 + 4 files changed, 77 insertions(+), 33 deletions(-) diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py index 11b079c469..ae305ecc10 100644 --- a/vllm_ascend/attention/mla_v1.py +++ b/vllm_ascend/attention/mla_v1.py @@ -82,6 +82,8 @@ class ChunkedContextMetadata: max_query_len: int max_seq_lens: int chunked_context: Optional[ChunkedContextMetadata] = None + sin: torch.Tensor = None + cos: torch.Tensor = None @dataclass @@ -95,6 +97,8 @@ class AscendMLADecodeMetadata: seq_lens_list: list[int] actual_seq_q_lens: Optional[list[int]] = None attn_mask: Optional[torch.Tensor] = None + sin: torch.Tensor = None + cos: torch.Tensor = None mc2_mask: Optional[torch.Tensor] = None @@ -207,9 +211,14 @@ def __init__(self, ) ascend_config = get_ascend_config() self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled - - def generate_active_mask(self, actual_seqs_num, batch_size): - mc2_mask = torch.zeros(batch_size, dtype=torch.bool, device=current_platform.device_type) + self.rope_dim = self.runner.model_config.hf_text_config.qk_rope_head_dim + self.cos_cache = None + self.sin_cache = None + + def generate_activate_mask(self, actual_seqs_num, batch_size): + mc2_mask = torch.zeros(batch_size, + dtype=torch.bool, + device=current_platform.device_type) mc2_mask[:actual_seqs_num].fill_(True) return mc2_mask @@ -343,7 +352,19 @@ def build_torchair_graph_dummy( else: attn_state = AscendAttentionState.DecodeOnly num_decode_tokens = 1 - mc2_mask = self.generate_active_mask(num_actual_tokens, num_reqs) + sin = torch.ones(num_reqs, + 1, + 1, + self.rope_dim, + dtype=self.runner.dtype, + device=device) + cos = torch.ones(num_reqs, + 1, + 1, + self.rope_dim, + dtype=self.runner.dtype, + device=device) + mc2_mask = self.generate_activate_mask(num_actual_tokens, num_reqs) decode_metadata = AscendMLADecodeMetadata( input_positions=input_positions, block_table=block_table, @@ -352,8 +373,9 @@ def build_torchair_graph_dummy( max_seq_lens=1, attn_mask=self.runner.spec_attn_mask, actual_seq_q_lens=self.runner.actual_seq_q_lens[:num_reqs], - mc2_mask=mc2_mask, - ) + sin=sin, + cos=cos, + mc2_mask=mc2_mask) return self.metadata_cls( # type: ignore num_input_tokens=num_actual_tokens, num_actual_tokens=num_actual_tokens, @@ -405,6 +427,16 @@ def 
build( max_query_len = query_lens.max().item() max_seq_lens = seq_lens.max().item() query_start_loc = common_attn_metadata.query_start_loc + if self.cos_cache is None: + self.cos_cache = self.runner.get_model( + ).model.layers[0].self_attn.rotary_emb.cos_cached + self.sin_cache = self.runner.get_model( + ).model.layers[0].self_attn.rotary_emb.sin_cached + if self.cos_cache.dtype != self.runner.dtype: # type: ignore + self.cos_cache = self.cos_cache.to( # type: ignore + self.runner.dtype) # type: ignore + self.sin_cache = self.sin_cache.to( # type: ignore + self.runner.dtype) # type: ignore prefill_metadata = None chunked_context_metadata = None @@ -451,18 +483,26 @@ def build( chunk_seq_lens=chunk_seq_lens, workspace=self.chunked_prefill_workspace, ) - + prefill_input_positions = input_positions[tokens_start:] + cos = self.cos_cache[ + prefill_input_positions].unsqueeze( # type: ignore + 1).unsqueeze(2) + sin = self.sin_cache[ + prefill_input_positions].unsqueeze( # type: ignore + 1).unsqueeze(2) prefill_metadata = AscendMLAPrefillMetadata( attn_mask=self.runner.attn_mask, query_lens=query_lens[tokens_start:], seq_lens=seq_lens, context_lens=seq_lens[tokens_start:], - input_positions=input_positions[tokens_start:], + input_positions=prefill_input_positions, block_table=block_table[reqs_start:, ...], max_query_len=max_query_len, max_seq_lens=max_seq_lens, query_start_loc=prefill_query_start_loc, chunked_context=chunked_context_metadata, + sin=sin, + cos=cos, ) decode_metadata = None @@ -507,9 +547,17 @@ def build( actual_seq_q_lens = query_start_loc[1:].tolist( ) + self.runner.actual_seq_q_lens[num_reqs:num_reqs + num_reqs_pad_size] + cos = self.cos_cache[ + input_positions].unsqueeze( # type: ignore + 1).unsqueeze(2) + sin = self.sin_cache[ + input_positions].unsqueeze( # type: ignore + 1).unsqueeze(2) else: seq_lens_list = seq_lens.tolist() - mc2_mask = self.generate_active_mask(num_actual_tokens, num_reqs) + cos, sin = None, None + mc2_mask = self.generate_activate_mask( + num_actual_tokens, num_reqs + num_reqs_pad_size) decode_metadata = AscendMLADecodeMetadata( input_positions=input_positions, @@ -519,8 +567,9 @@ def build( max_seq_lens=max_seq_lens, attn_mask=self.runner.spec_attn_mask, actual_seq_q_lens=actual_seq_q_lens, - mc2_mask=mc2_mask, - ) + sin=sin, + cos=cos, + mc2_mask=mc2_mask) return self.metadata_cls( # type: ignore num_actual_tokens=num_actual_tokens, @@ -1112,15 +1161,8 @@ def forward( decode_k_nope = None assert attn_metadata.decode is not None if self.running_in_graph: - seq_len = self.rotary_emb.max_position_embeddings * self.rotary_emb.scaling_factor - cos = self.rotary_emb.cos_cached[:seq_len].to( - dtype=decode_hs_or_q_c.dtype) - sin = self.rotary_emb.sin_cached[:seq_len].to( - dtype=decode_hs_or_q_c.dtype) - cos = cos[attn_metadata.decode.input_positions] - sin = sin[attn_metadata.decode.input_positions] - cos = cos[:, None, None, :] - sin = sin[:, None, None, :] + cos = attn_metadata.decode.cos + sin = attn_metadata.decode.sin # Without explicitly controlling the order, IndexByTensor operations # would be placed after `matmul W_KV_T` hindering the overlapping of # KvRmsNormRopeCache and SingleRope. 
@@ -1155,15 +1197,8 @@ def forward( prefill_q_nope = prefill_q[..., :self.qk_nope_head_dim] if self.torchair_graph_enabled: num_tokens = prefill_hs_or_q_c.shape[0] - seq_len = self.rotary_emb.max_position_embeddings * self.rotary_emb.scaling_factor - cos = self.rotary_emb.cos_cached[:seq_len].to( - dtype=prefill_q_pe.dtype) - sin = self.rotary_emb.sin_cached[:seq_len].to( - dtype=prefill_q_pe.dtype) - cos = cos[attn_metadata.prefill.input_positions] - sin = sin[attn_metadata.prefill.input_positions] - cos = cos[:, None, None, :] - sin = sin[:, None, None, :] + cos = attn_metadata.prefill.cos + sin = attn_metadata.prefill.sin prefill_q_pe = self.rope_single(prefill_q_pe, cos, sin) prefill_k_pe, prefill_k_nope = self.exec_kv_prefill( diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py index a3ad96519b..a85877a56f 100644 --- a/vllm_ascend/ops/fused_moe.py +++ b/vllm_ascend/ops/fused_moe.py @@ -138,7 +138,7 @@ def fused_experts_with_mc2( # NOTE: Currently, when in A3 or in torchair graph, we need to pass in some extra param into dispatch & combine need_extra_args = (get_ascend_soc_version() == AscendSocVersion.A3 or is_torchair) - + # NOTE: Currently, when in A3, we need to pass in some extra param into dispatch & combine a3_need_extra_args = get_ascend_soc_version() == AscendSocVersion.A3 @@ -1168,7 +1168,7 @@ def forward(self, if shared_experts: if not self.enable_multistream_moe or fused_moe_state != FusedMoEState.MC2: shared_hidden_states = shared_experts(hidden_states) - + attn_metadata = get_forward_context().attn_metadata mc2_mask = attn_metadata.decode.mc2_mask if attn_metadata is not None and attn_metadata.decode is not None else None @@ -1179,6 +1179,9 @@ def forward(self, hidden_states, (0, 0, 0, tp_size - num_tokens)) router_logits = nn.functional.pad( router_logits, (0, 0, 0, tp_size - num_tokens)) + if mc2_mask is not None: + mc2_mask = nn.functional.pad(mc2_mask, + (0, tp_size - num_tokens)) chunk_hidden_states = torch.tensor_split(hidden_states, tp_size, dim=0) @@ -1188,9 +1191,11 @@ def forward(self, tp_rank = get_tensor_model_parallel_rank() hidden_states = chunk_hidden_states[tp_rank] router_logits = chunk_router_logits[tp_rank] + if mc2_mask is not None: chunk_mc2_mask = torch.tensor_split(mc2_mask, tp_size, dim=0) mc2_mask = chunk_mc2_mask[tp_rank] + if self.dp_size > 1 and fused_moe_state == FusedMoEState.AllGather: # NOTE: When in torchair graph, it has been padded in model_runner_v1 if not self.torchair_graph_enabled or is_prefill: diff --git a/vllm_ascend/quantization/w8a8_dynamic.py b/vllm_ascend/quantization/w8a8_dynamic.py index b825183ea4..3561675095 100644 --- a/vllm_ascend/quantization/w8a8_dynamic.py +++ b/vllm_ascend/quantization/w8a8_dynamic.py @@ -233,7 +233,7 @@ def fused_experts_with_mc2( # NOTE: Currently, when in A3 or in torchair graph, we need to pass in some extra param into dispatch & combine need_extra_args = (get_ascend_soc_version() == AscendSocVersion.A3 or is_torchair) - + # NOTE: Currently, when in A3, we need to pass in some extra param into dispatch & combine a3_need_extra_args = get_ascend_soc_version() == AscendSocVersion.A3 @@ -803,6 +803,7 @@ def apply( topk_weights = topk_weights.to(x.dtype) if fused_moe_state == FusedMoEState.MC2: + mc2_mask = kwargs.get("mc2_mask", None) return fused_experts_with_mc2( hidden_states=x, w1=layer.w13_weight, @@ -819,7 +820,8 @@ def apply( shared_experts=shared_experts, is_torchair=self.torchair_graph_enabled, quantized_x_for_share=shared_gate_up, - 
dynamic_scale_for_share=shared_dequant_scale) + dynamic_scale_for_share=shared_dequant_scale, + mc2_mask=mc2_mask) elif fused_moe_state == FusedMoEState.AllGather: return fused_experts(hidden_states=x, w1=layer.w13_weight, diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index a59bb3a9b4..12f502a4d7 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -1666,6 +1666,8 @@ def _dummy_run( attn_metadata.decode.block_table) torch._dynamo.mark_static( attn_metadata.decode.input_positions) + torch._dynamo.mark_static(attn_metadata.decode.sin) + torch._dynamo.mark_static(attn_metadata.decode.cos) torch._dynamo.mark_static( attn_metadata.decode.mc2_mask) torch._dynamo.mark_static(attn_metadata.slot_mapping) From e782e3de78442deb40510537d0069dcf68746a01 Mon Sep 17 00:00:00 2001 From: weiguihua2 Date: Tue, 8 Jul 2025 16:52:04 +0800 Subject: [PATCH 3/4] add mc2 mask Signed-off-by: weiguihua2 --- vllm_ascend/attention/mla_v1.py | 18 ++---------------- vllm_ascend/worker/model_runner_v1.py | 2 -- 2 files changed, 2 insertions(+), 18 deletions(-) diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py index ae305ecc10..5e1be9fad9 100644 --- a/vllm_ascend/attention/mla_v1.py +++ b/vllm_ascend/attention/mla_v1.py @@ -11,7 +11,6 @@ from vllm.config import get_current_vllm_config from vllm.model_executor.layers.linear import (LinearBase, UnquantizedLinearMethod) -from vllm.platforms import current_platform from vllm.utils import cdiv, round_down from vllm_ascend import envs @@ -99,7 +98,6 @@ class AscendMLADecodeMetadata: attn_mask: Optional[torch.Tensor] = None sin: torch.Tensor = None cos: torch.Tensor = None - mc2_mask: Optional[torch.Tensor] = None @dataclass @@ -215,13 +213,6 @@ def __init__(self, self.cos_cache = None self.sin_cache = None - def generate_activate_mask(self, actual_seqs_num, batch_size): - mc2_mask = torch.zeros(batch_size, - dtype=torch.bool, - device=current_platform.device_type) - mc2_mask[:actual_seqs_num].fill_(True) - return mc2_mask - def reorder_batch(self, input_batch: "InputBatch", scheduler_output: "SchedulerOutput") -> bool: # We now want to reorder the batch so that the "decode" requests are at @@ -364,7 +355,6 @@ def build_torchair_graph_dummy( self.rope_dim, dtype=self.runner.dtype, device=device) - mc2_mask = self.generate_activate_mask(num_actual_tokens, num_reqs) decode_metadata = AscendMLADecodeMetadata( input_positions=input_positions, block_table=block_table, @@ -374,8 +364,7 @@ def build_torchair_graph_dummy( attn_mask=self.runner.spec_attn_mask, actual_seq_q_lens=self.runner.actual_seq_q_lens[:num_reqs], sin=sin, - cos=cos, - mc2_mask=mc2_mask) + cos=cos) return self.metadata_cls( # type: ignore num_input_tokens=num_actual_tokens, num_actual_tokens=num_actual_tokens, @@ -556,8 +545,6 @@ def build( else: seq_lens_list = seq_lens.tolist() cos, sin = None, None - mc2_mask = self.generate_activate_mask( - num_actual_tokens, num_reqs + num_reqs_pad_size) decode_metadata = AscendMLADecodeMetadata( input_positions=input_positions, @@ -568,8 +555,7 @@ def build( attn_mask=self.runner.spec_attn_mask, actual_seq_q_lens=actual_seq_q_lens, sin=sin, - cos=cos, - mc2_mask=mc2_mask) + cos=cos) return self.metadata_cls( # type: ignore num_actual_tokens=num_actual_tokens, diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 12f502a4d7..11ee1ab526 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ 
b/vllm_ascend/worker/model_runner_v1.py @@ -1668,8 +1668,6 @@ def _dummy_run( attn_metadata.decode.input_positions) torch._dynamo.mark_static(attn_metadata.decode.sin) torch._dynamo.mark_static(attn_metadata.decode.cos) - torch._dynamo.mark_static( - attn_metadata.decode.mc2_mask) torch._dynamo.mark_static(attn_metadata.slot_mapping) for kv in self.kv_caches: assert isinstance( From 6782350c353ffcdb0856c25f1cd0bf589b94866c Mon Sep 17 00:00:00 2001 From: weiguihua2 Date: Tue, 8 Jul 2025 16:59:13 +0800 Subject: [PATCH 4/4] add mc2 mask Signed-off-by: weiguihua2 --- vllm_ascend/attention/mla_v1.py | 18 ++++++++++++++++-- vllm_ascend/worker/model_runner_v1.py | 2 ++ 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py index 5e1be9fad9..ae305ecc10 100644 --- a/vllm_ascend/attention/mla_v1.py +++ b/vllm_ascend/attention/mla_v1.py @@ -11,6 +11,7 @@ from vllm.config import get_current_vllm_config from vllm.model_executor.layers.linear import (LinearBase, UnquantizedLinearMethod) +from vllm.platforms import current_platform from vllm.utils import cdiv, round_down from vllm_ascend import envs @@ -98,6 +99,7 @@ class AscendMLADecodeMetadata: attn_mask: Optional[torch.Tensor] = None sin: torch.Tensor = None cos: torch.Tensor = None + mc2_mask: Optional[torch.Tensor] = None @dataclass @@ -213,6 +215,13 @@ def __init__(self, self.cos_cache = None self.sin_cache = None + def generate_activate_mask(self, actual_seqs_num, batch_size): + mc2_mask = torch.zeros(batch_size, + dtype=torch.bool, + device=current_platform.device_type) + mc2_mask[:actual_seqs_num].fill_(True) + return mc2_mask + def reorder_batch(self, input_batch: "InputBatch", scheduler_output: "SchedulerOutput") -> bool: # We now want to reorder the batch so that the "decode" requests are at @@ -355,6 +364,7 @@ def build_torchair_graph_dummy( self.rope_dim, dtype=self.runner.dtype, device=device) + mc2_mask = self.generate_activate_mask(num_actual_tokens, num_reqs) decode_metadata = AscendMLADecodeMetadata( input_positions=input_positions, block_table=block_table, @@ -364,7 +374,8 @@ def build_torchair_graph_dummy( attn_mask=self.runner.spec_attn_mask, actual_seq_q_lens=self.runner.actual_seq_q_lens[:num_reqs], sin=sin, - cos=cos) + cos=cos, + mc2_mask=mc2_mask) return self.metadata_cls( # type: ignore num_input_tokens=num_actual_tokens, num_actual_tokens=num_actual_tokens, @@ -545,6 +556,8 @@ def build( else: seq_lens_list = seq_lens.tolist() cos, sin = None, None + mc2_mask = self.generate_activate_mask( + num_actual_tokens, num_reqs + num_reqs_pad_size) decode_metadata = AscendMLADecodeMetadata( input_positions=input_positions, @@ -555,7 +568,8 @@ def build( attn_mask=self.runner.spec_attn_mask, actual_seq_q_lens=actual_seq_q_lens, sin=sin, - cos=cos) + cos=cos, + mc2_mask=mc2_mask) return self.metadata_cls( # type: ignore num_actual_tokens=num_actual_tokens, diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 11ee1ab526..12f502a4d7 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -1668,6 +1668,8 @@ def _dummy_run( attn_metadata.decode.input_positions) torch._dynamo.mark_static(attn_metadata.decode.sin) torch._dynamo.mark_static(attn_metadata.decode.cos) + torch._dynamo.mark_static( + attn_metadata.decode.mc2_mask) torch._dynamo.mark_static(attn_metadata.slot_mapping) for kv in self.kv_caches: assert isinstance(
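
Note for reviewers: the core idea running through this patch series is small — on the A3 SoC, the MC2 dispatch/combine kernels accept an `x_active_mask` that marks which rows of the (possibly padded) decode batch are real requests, so padding rows added for the torchair graph or for tensor parallelism are skipped. Below is a minimal, self-contained sketch of that mask logic using plain CPU torch tensors. It mirrors `generate_activate_mask` in mla_v1.py and the pad/split step in `AscendFusedMoE.forward`, but the helper names (`build_mc2_mask`, `pad_and_shard_for_tp`) are hypothetical and the real code places the mask on the NPU device and feeds it to `torch_npu.npu_moe_distribute_dispatch`/`_combine`; this is an illustration, not the actual vllm-ascend API.

    # Sketch only: mirrors the mc2_mask handling added in this patch series,
    # using CPU tensors and hypothetical helper names.
    import torch
    import torch.nn.functional as F


    def build_mc2_mask(actual_seqs_num: int, batch_size: int) -> torch.Tensor:
        """Boolean mask over the padded decode batch: True for real requests,
        False for padding rows (same shape and logic as generate_activate_mask,
        minus the NPU device placement)."""
        mc2_mask = torch.zeros(batch_size, dtype=torch.bool)
        mc2_mask[:actual_seqs_num].fill_(True)
        return mc2_mask


    def pad_and_shard_for_tp(mc2_mask: torch.Tensor, tp_size: int,
                             tp_rank: int) -> torch.Tensor:
        """Pad the mask so it covers the TP-padded token dimension, then take
        this rank's chunk -- padded rows stay False, so the MC2 kernels treat
        them as inactive (mirrors the pad/tensor_split in AscendFusedMoE.forward)."""
        num_tokens = mc2_mask.shape[0]
        if num_tokens < tp_size:
            mc2_mask = F.pad(mc2_mask, (0, tp_size - num_tokens))
        chunks = torch.tensor_split(mc2_mask, tp_size, dim=0)
        return chunks[tp_rank]


    if __name__ == "__main__":
        # 3 real decode requests in a graph batch padded to 8.
        mask = build_mc2_mask(actual_seqs_num=3, batch_size=8)
        print(mask)  # tensor([True, True, True, False, False, False, False, False])
        # Each TP rank keeps only its slice of the mask, matching its slice
        # of hidden_states and router_logits.
        print(pad_and_shard_for_tp(mask, tp_size=2, tp_rank=0))

Keeping the mask in the decode metadata (and marking it static for dynamo in patch 2/4) lets the same tensor be reused across graph replays instead of being rebuilt per layer.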