add mc2 mask #1642

Merged · 5 commits · Jul 8, 2025

18 changes: 16 additions & 2 deletions vllm_ascend/attention/mla_v1.py
@@ -11,6 +11,7 @@
from vllm.config import get_current_vllm_config
from vllm.model_executor.layers.linear import (LinearBase,
UnquantizedLinearMethod)
from vllm.platforms import current_platform
from vllm.utils import cdiv, round_down

from vllm_ascend import envs
@@ -98,6 +99,7 @@ class AscendMLADecodeMetadata:
attn_mask: Optional[torch.Tensor] = None
sin: torch.Tensor = None
cos: torch.Tensor = None
mc2_mask: Optional[torch.Tensor] = None


@dataclass
@@ -213,6 +215,13 @@ def __init__(self,
self.cos_cache = None
self.sin_cache = None

def generate_activate_mask(self, actual_seqs_num, batch_size):
mc2_mask = torch.zeros(batch_size,
dtype=torch.bool,
device=current_platform.device_type)
mc2_mask[:actual_seqs_num].fill_(True)
return mc2_mask

def reorder_batch(self, input_batch: "InputBatch",
scheduler_output: "SchedulerOutput") -> bool:
# We now want to reorder the batch so that the "decode" requests are at
@@ -355,6 +364,7 @@ def build_torchair_graph_dummy(
self.rope_dim,
dtype=self.runner.dtype,
device=device)
mc2_mask = self.generate_activate_mask(num_actual_tokens, num_reqs)
decode_metadata = AscendMLADecodeMetadata(
input_positions=input_positions,
block_table=block_table,
@@ … @@
attn_mask=self.runner.spec_attn_mask,
actual_seq_q_lens=self.runner.actual_seq_q_lens[:num_reqs],
sin=sin,
cos=cos)
cos=cos,
mc2_mask=mc2_mask)
return self.metadata_cls( # type: ignore
num_input_tokens=num_actual_tokens,
num_actual_tokens=num_actual_tokens,
@@ -545,6 +556,8 @@ def build(
else:
seq_lens_list = seq_lens.tolist()
cos, sin = None, None
mc2_mask = self.generate_activate_mask(
num_actual_tokens, num_reqs + num_reqs_pad_size)

decode_metadata = AscendMLADecodeMetadata(
input_positions=input_positions,
@@ … @@
attn_mask=self.runner.spec_attn_mask,
actual_seq_q_lens=actual_seq_q_lens,
sin=sin,
cos=cos)
cos=cos,
mc2_mask=mc2_mask)

return self.metadata_cls( # type: ignore
num_actual_tokens=num_actual_tokens,
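Note: as a quick illustration of what the new generate_activate_mask helper returns, here is a minimal standalone sketch (the CPU device stands in for the NPU device used in the real code):

import torch

def generate_activate_mask(actual_seqs_num: int, batch_size: int,
                           device: str = "cpu") -> torch.Tensor:
    # Boolean mask over the (possibly padded) decode batch: the first
    # `actual_seqs_num` slots are real requests, the trailing slots are padding.
    mc2_mask = torch.zeros(batch_size, dtype=torch.bool, device=device)
    mc2_mask[:actual_seqs_num].fill_(True)
    return mc2_mask

# 3 real decode requests padded to a batch of 8:
# tensor([ True,  True,  True, False, False, False, False, False])
print(generate_activate_mask(3, 8))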
28 changes: 27 additions & 1 deletion vllm_ascend/ops/fused_moe.py
@@ -122,6 +122,7 @@ def fused_experts_with_mc2(
moe_all_to_all_group_name: Optional[str] = None,
shared_experts: Optional[Any] = None,
is_torchair: bool = False,
mc2_mask: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
quant_mode = 0
ep_group = get_ep_group()
@@ … @@
need_extra_args = (get_ascend_soc_version() == AscendSocVersion.A3
or is_torchair)

# NOTE: Currently, on A3 we need to pass some extra parameters to dispatch & combine
a3_need_extra_args = get_ascend_soc_version() == AscendSocVersion.A3

moe_expert_num = len(expert_map)
kwargs_mc2 = {
"x": hidden_states,
@@ … @@
"tp_world_size": 1,
"tp_rank_id": 0,
})
if a3_need_extra_args:
stage1_kwargs.update({
"x_active_mask": mc2_mask,
})

kwargs_mc2.update(stage1_kwargs)

@@ -230,6 +238,10 @@ def fused_experts_with_mc2(
"tp_world_size": 1,
"tp_rank_id": 0,
})
if a3_need_extra_args:
stage3_kwargs.update({
"x_active_mask": mc2_mask,
})
kwargs_mc2.update(stage3_kwargs)

hidden_states = torch_npu.npu_moe_distribute_combine(**kwargs_mc2)
@@ -944,6 +956,7 @@ def apply(

fused_moe_state = get_forward_context().fused_moe_state
if fused_moe_state == FusedMoEState.MC2:
mc2_mask = kwargs.get("mc2_mask", None)
return fused_experts_with_mc2(
hidden_states=x,
w1=layer.w13_weight,
@@ … @@
expert_map=expert_map,
moe_all_to_all_group_name=self.moe_all_to_all_group_name,
shared_experts=shared_experts,
is_torchair=self.torchair_graph_enabled)
is_torchair=self.torchair_graph_enabled,
mc2_mask=mc2_mask)
elif fused_moe_state == FusedMoEState.AllGather:
return fused_experts(hidden_states=x,
w1=layer.w13_weight,
@@ -1155,13 +1169,19 @@ def forward(self,
if not self.enable_multistream_moe or fused_moe_state != FusedMoEState.MC2:
shared_hidden_states = shared_experts(hidden_states)

attn_metadata = get_forward_context().attn_metadata
mc2_mask = attn_metadata.decode.mc2_mask if attn_metadata is not None and attn_metadata.decode is not None else None

tp_size = get_tensor_model_parallel_world_size()
if tp_size > 1 and fused_moe_state != FusedMoEState.AllGather:
if num_tokens < tp_size:
hidden_states = nn.functional.pad(
hidden_states, (0, 0, 0, tp_size - num_tokens))
router_logits = nn.functional.pad(
router_logits, (0, 0, 0, tp_size - num_tokens))
if mc2_mask is not None:
mc2_mask = nn.functional.pad(mc2_mask,
(0, tp_size - num_tokens))
chunk_hidden_states = torch.tensor_split(hidden_states,
tp_size,
dim=0)
@@ … @@
tp_rank = get_tensor_model_parallel_rank()
hidden_states = chunk_hidden_states[tp_rank]
router_logits = chunk_router_logits[tp_rank]

if mc2_mask is not None:
chunk_mc2_mask = torch.tensor_split(mc2_mask, tp_size, dim=0)
mc2_mask = chunk_mc2_mask[tp_rank]

if self.dp_size > 1 and fused_moe_state == FusedMoEState.AllGather:
# NOTE: When in torchair graph, it has been padded in model_runner_v1
if not self.torchair_graph_enabled or is_prefill:
@@ -1209,6 +1234,7 @@ def forward(self,
and self.enable_multistream_moe and not is_prefill else None,
quantized_x_for_share=quantized_x_for_share,
dynamic_scale_for_share=dynamic_scale_for_share,
mc2_mask=mc2_mask,
)

if shared_experts:
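Note: the forward() change above keeps the mask aligned with the token dimension when tensor parallelism pads and chunks the batch. A rough sketch of that handling, with illustrative argument names (the real code reads tp_size/tp_rank from the parallel state):

import torch
import torch.nn as nn

def pad_and_split_mc2_mask(mc2_mask: torch.Tensor, num_tokens: int,
                           tp_size: int, tp_rank: int) -> torch.Tensor:
    # Pad the 1-D bool mask with False so its length matches the padded
    # hidden_states / router_logits when num_tokens < tp_size.
    if num_tokens < tp_size:
        mc2_mask = nn.functional.pad(mc2_mask, (0, tp_size - num_tokens))
    # Each TP rank keeps only its own chunk, so the mask stays aligned with
    # the hidden-state chunk that this rank dispatches to the experts.
    return torch.tensor_split(mc2_mask, tp_size, dim=0)[tp_rank]

# Example: 3 real tokens, TP size 4 -> rank 0 keeps tensor([True])
print(pad_and_split_mc2_mask(torch.tensor([True, True, True]), 3, 4, 0))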
16 changes: 15 additions & 1 deletion vllm_ascend/quantization/w8a8_dynamic.py
@@ -215,6 +215,7 @@ def fused_experts_with_mc2(
w2_scale_bias: torch.Tensor = None,
quantized_x_for_share: Optional[Any] = None,
dynamic_scale_for_share: Optional[Any] = None,
mc2_mask: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
if log2phy:
topk_ids = log2phy[topk_ids]
@@ … @@
need_extra_args = (get_ascend_soc_version() == AscendSocVersion.A3
or is_torchair)

# NOTE: Currently, on A3 we need to pass some extra parameters to dispatch & combine
a3_need_extra_args = get_ascend_soc_version() == AscendSocVersion.A3

if (expert_map is not None):
moe_expert_num = len(expert_map) + global_redundant_expert_num
else:
@@ -260,6 +264,10 @@
"tp_world_size": 1,
"tp_rank_id": 0,
})
if a3_need_extra_args:
stage1_kwargs.update({
"x_active_mask": mc2_mask,
})
kwargs_mc2.update(stage1_kwargs)

output = torch_npu.npu_moe_distribute_dispatch(**kwargs_mc2)
@@ -310,6 +318,10 @@
"tp_world_size": 1,
"tp_rank_id": 0,
})
if a3_need_extra_args:
stage3_kwargs.update({
"x_active_mask": mc2_mask,
})
kwargs_mc2.update(stage3_kwargs)

hidden_states = torch_npu.npu_moe_distribute_combine(**kwargs_mc2)
@@ -791,6 +803,7 @@ def apply(
topk_weights = topk_weights.to(x.dtype)

if fused_moe_state == FusedMoEState.MC2:
mc2_mask = kwargs.get("mc2_mask", None)
return fused_experts_with_mc2(
hidden_states=x,
w1=layer.w13_weight,
@@ … @@
shared_experts=shared_experts,
is_torchair=self.torchair_graph_enabled,
quantized_x_for_share=shared_gate_up,
dynamic_scale_for_share=shared_dequant_scale)
dynamic_scale_for_share=shared_dequant_scale,
mc2_mask=mc2_mask)
elif fused_moe_state == FusedMoEState.AllGather:
return fused_experts(hidden_states=x,
w1=layer.w13_weight,
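Note: both the unquantized and the w8a8_dynamic path thread the mask into the dispatch and combine kwargs only on A3. A condensed sketch of that plumbing (the helper below is a simplified stand-in for the real stage1_kwargs/stage3_kwargs construction; x_active_mask is the parameter name the PR actually passes):

from typing import Optional

import torch

def stage_extra_kwargs(is_a3: bool, is_torchair: bool,
                       mc2_mask: Optional[torch.Tensor]) -> dict:
    kwargs: dict = {}
    # A3 SoCs and torchair graph mode both need the extra TP fields.
    if is_a3 or is_torchair:
        kwargs.update({
            "tp_world_size": 1,
            "tp_rank_id": 0,
        })
    # Only A3 additionally takes the per-token activation mask, so padded
    # slots are skipped by npu_moe_distribute_dispatch / _combine.
    if is_a3:
        kwargs.update({"x_active_mask": mc2_mask})
    return kwargs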
2 changes: 2 additions & 0 deletions vllm_ascend/worker/model_runner_v1.py
@@ -1668,6 +1668,8 @@ def _dummy_run(
attn_metadata.decode.input_positions)
torch._dynamo.mark_static(attn_metadata.decode.sin)
torch._dynamo.mark_static(attn_metadata.decode.cos)
torch._dynamo.mark_static(
attn_metadata.decode.mc2_mask)
torch._dynamo.mark_static(attn_metadata.slot_mapping)
for kv in self.kv_caches:
assert isinstance(
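Note: the runner change only registers the new mask as a static graph input alongside the other decode-metadata tensors; a hedged illustration of the pattern (the attribute names match AscendMLADecodeMetadata above):

import torch

def mark_decode_tensors_static(decode_metadata) -> None:
    # Tell torch.compile / torchair graph capture to treat these tensors'
    # shapes as static, so the captured graph is replayed without
    # dynamic-shape recompilation.
    torch._dynamo.mark_static(decode_metadata.input_positions)
    torch._dynamo.mark_static(decode_metadata.sin)
    torch._dynamo.mark_static(decode_metadata.cos)
    torch._dynamo.mark_static(decode_metadata.mc2_mask)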