Commit c9bbd0c

add mc2 mask
Signed-off-by: weiguihua2 <weiguihua2@huawei.com>
1 parent fd8b9b6 commit c9bbd0c

File tree

3 files changed: +10 -3 lines changed


vllm_ascend/attention/mla_v1.py

Lines changed: 5 additions & 2 deletions

@@ -221,7 +221,9 @@ def __init__(self,
         self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
 
     def generate_active_mask(self, actual_seqs_num, batch_size):
-        mc2_mask = torch.zeros(batch_size, dtype=torch.bool, device=current_platform.device_type)
+        mc2_mask = torch.zeros(batch_size,
+                               dtype=torch.bool,
+                               device=current_platform.device_type)
         mc2_mask[:actual_seqs_num].fill_(True)
         return mc2_mask
 
@@ -521,7 +523,8 @@ def build(
                                         num_reqs_pad_size]
         else:
             seq_lens_list = seq_lens.tolist()
-        mc2_mask = self.generate_active_mask(num_actual_tokens, num_reqs)
+        mc2_mask = self.generate_active_mask(
+            num_actual_tokens, num_reqs + num_reqs_pad_size)
 
         decode_metadata = AscendMLADecodeMetadata(
             input_positions=input_positions,
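
For context, generate_active_mask builds a boolean vector over the (possibly padded) decode batch: the first actual_seqs_num slots are True and the padded tail stays False, which is why the build() hunk now sizes the mask to num_reqs + num_reqs_pad_size instead of num_reqs. A minimal standalone sketch of the same idea (CPU tensors and a hypothetical helper name, not the module above):

import torch

def build_active_mask(actual_seqs_num: int, batch_size: int,
                      device: str = "cpu") -> torch.Tensor:
    # batch_size may include padding slots added for graph capture;
    # only the first actual_seqs_num requests are real.
    mask = torch.zeros(batch_size, dtype=torch.bool, device=device)
    mask[:actual_seqs_num].fill_(True)
    return mask

# 3 real requests padded up to a batch of 5
print(build_active_mask(3, 5))   # tensor([ True,  True,  True, False, False])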

vllm_ascend/ops/fused_moe.py

Lines changed: 2 additions & 0 deletions

@@ -1188,9 +1188,11 @@ def forward(self,
             tp_rank = get_tensor_model_parallel_rank()
             hidden_states = chunk_hidden_states[tp_rank]
             router_logits = chunk_router_logits[tp_rank]
+
             if mc2_mask is not None:
                 chunk_mc2_mask = torch.tensor_split(mc2_mask, tp_size, dim=0)
                 mc2_mask = chunk_mc2_mask[tp_rank]
+
         if self.dp_size > 1 and fused_moe_state == FusedMoEState.AllGather:
             # NOTE: When in torchair graph, it has been padded in model_runner_v1
             if not self.torchair_graph_enabled or is_prefill:
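
Beyond the two added blank lines, this existing block splits mc2_mask across tensor-parallel ranks the same way hidden_states and router_logits are already chunked, so each rank keeps only the slice of the mask that matches its share of the batch. A small sketch of that chunking, assuming a toy mask of 8 requests and tp_size = 4 (the variable names are illustrative, not the module's API):

import torch

tp_size = 4
mc2_mask = torch.tensor([True] * 5 + [False] * 3)   # 5 active, 3 padded requests

# torch.tensor_split mirrors how hidden_states / router_logits are chunked
chunk_mc2_mask = torch.tensor_split(mc2_mask, tp_size, dim=0)
for tp_rank, chunk in enumerate(chunk_mc2_mask):
    print(tp_rank, chunk.tolist())
# 0 [True, True]
# 1 [True, True]
# 2 [True, False]
# 3 [False, False]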

vllm_ascend/quantization/w8a8_dynamic.py

Lines changed: 3 additions & 1 deletion

@@ -803,6 +803,7 @@ def apply(
         topk_weights = topk_weights.to(x.dtype)
 
         if fused_moe_state == FusedMoEState.MC2:
+            mc2_mask = kwargs.get("mc2_mask", None)
             return fused_experts_with_mc2(
                 hidden_states=x,
                 w1=layer.w13_weight,
@@ -819,7 +820,8 @@ def apply(
                 shared_experts=shared_experts,
                 is_torchair=self.torchair_graph_enabled,
                 quantized_x_for_share=shared_gate_up,
-                dynamic_scale_for_share=shared_dequant_scale)
+                dynamic_scale_for_share=shared_dequant_scale,
+                mc2_mask=mc2_mask)
         elif fused_moe_state == FusedMoEState.AllGather:
             return fused_experts(hidden_states=x,
                                  w1=layer.w13_weight,
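
On the quantized path the mask is simply threaded through: it is read from kwargs (defaulting to None when the caller does not pass one) and forwarded to fused_experts_with_mc2. A hedged sketch of that pass-through pattern, with placeholder stub functions standing in for the real kernels:

from typing import Optional
import torch

def fused_experts_with_mc2_stub(hidden_states: torch.Tensor,
                                mc2_mask: Optional[torch.Tensor] = None,
                                **_kwargs) -> torch.Tensor:
    # Placeholder for the real MC2 kernel: only shows that the mask
    # travels with the call and can gate which rows do real work.
    if mc2_mask is not None:
        return hidden_states * mc2_mask.unsqueeze(-1)
    return hidden_states

def apply_stub(x: torch.Tensor, **kwargs) -> torch.Tensor:
    # Mirrors the diff: pull the optional mask out of kwargs and forward it.
    mc2_mask = kwargs.get("mc2_mask", None)
    return fused_experts_with_mc2_stub(hidden_states=x, mc2_mask=mc2_mask)

x = torch.ones(4, 2)
mask = torch.tensor([True, True, False, False])
print(apply_stub(x, mc2_mask=mask))   # rows 2 and 3 are zeroed out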
