Commit 88c31da

committed: add mc2 mask
Signed-off-by: weiguihua2 <weiguihua2@huawei.com>
1 parent 5559443 commit 88c31da

File tree: 4 files changed (+47 −1 lines changed)

vllm_ascend/attention/mla_v1.py

Lines changed: 11 additions & 0 deletions
@@ -11,6 +11,7 @@
 from vllm.config import get_current_vllm_config
 from vllm.model_executor.layers.linear import (LinearBase,
                                                UnquantizedLinearMethod)
+from vllm.platforms import current_platform
 from vllm.utils import cdiv, round_down

 from vllm_ascend import envs
@@ -94,6 +95,7 @@ class AscendMLADecodeMetadata:
     seq_lens_list: list[int]
     actual_seq_q_lens: Optional[list[int]] = None
     attn_mask: Optional[torch.Tensor] = None
+    mc2_mask: Optional[torch.Tensor] = None


 @dataclass
@@ -206,6 +208,11 @@ def __init__(self,
         ascend_config = get_ascend_config()
         self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled

+    def generate_active_mask(self, actual_seqs_num, batch_size):
+        mc2_mask = torch.zeros(batch_size, dtype=torch.bool, device=current_platform.device_type)
+        mc2_mask[:actual_seqs_num].fill_(True)
+        return mc2_mask
+
     def reorder_batch(self, input_batch: "InputBatch",
                       scheduler_output: "SchedulerOutput") -> bool:
         # We now want to reorder the batch so that the "decode" requests are at
@@ -336,6 +343,7 @@ def build_torchair_graph_dummy(
         else:
             attn_state = AscendAttentionState.DecodeOnly
             num_decode_tokens = 1
+        mc2_mask = self.generate_active_mask(num_actual_tokens, num_reqs)
         decode_metadata = AscendMLADecodeMetadata(
             input_positions=input_positions,
             block_table=block_table,
@@ -344,6 +352,7 @@ def build_torchair_graph_dummy(
             max_seq_lens=1,
             attn_mask=self.runner.spec_attn_mask,
             actual_seq_q_lens=self.runner.actual_seq_q_lens[:num_reqs],
+            mc2_mask=mc2_mask,
         )
         return self.metadata_cls(  # type: ignore
             num_input_tokens=num_actual_tokens,
@@ -500,6 +509,7 @@ def build(
                                       num_reqs_pad_size]
         else:
             seq_lens_list = seq_lens.tolist()
+        mc2_mask = self.generate_active_mask(num_actual_tokens, num_reqs)

         decode_metadata = AscendMLADecodeMetadata(
             input_positions=input_positions,
@@ -509,6 +519,7 @@ def build(
             max_seq_lens=max_seq_lens,
             attn_mask=self.runner.spec_attn_mask,
             actual_seq_q_lens=actual_seq_q_lens,
+            mc2_mask=mc2_mask,
        )

         return self.metadata_cls(  # type: ignore
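
The helper below is a standalone sketch of the generate_active_mask logic added above. It uses a plain CPU tensor (the real code allocates on current_platform.device_type, i.e. the NPU), and the sample sizes are made up, so it can be run anywhere to see what the mask looks like for a padded decode batch.

    import torch

    def generate_active_mask(actual_seqs_num: int, batch_size: int,
                             device: str = "cpu") -> torch.Tensor:
        # The first `actual_seqs_num` slots are real decode requests; the
        # remaining slots of the (padded) batch stay inactive.
        mc2_mask = torch.zeros(batch_size, dtype=torch.bool, device=device)
        mc2_mask[:actual_seqs_num].fill_(True)
        return mc2_mask

    # 3 real requests in a batch padded to 8 slots.
    print(generate_active_mask(3, 8))
    # tensor([ True,  True,  True, False, False, False, False, False])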

vllm_ascend/ops/fused_moe.py

Lines changed: 22 additions & 1 deletion
@@ -122,6 +122,7 @@ def fused_experts_with_mc2(
     moe_all_to_all_group_name: Optional[str] = None,
     shared_experts: Optional[Any] = None,
     is_torchair: bool = False,
+    mc2_mask: Optional[torch.Tensor] = None,
 ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
     quant_mode = 0
     ep_group = get_ep_group()
@@ -137,6 +138,9 @@ def fused_experts_with_mc2(
     # NOTE: Currently, when in A3 or in torchair graph, we need to pass in some extra param into dispatch & combine
     need_extra_args = (get_ascend_soc_version() == AscendSocVersion.A3
                        or is_torchair)
+
+    # NOTE: Currently, when in A3, we need to pass in some extra param into dispatch & combine
+    a3_need_extra_args = get_ascend_soc_version() == AscendSocVersion.A3

     moe_expert_num = len(expert_map)
     kwargs_mc2 = {
@@ -161,6 +165,10 @@ def fused_experts_with_mc2(
             "tp_world_size": 1,
             "tp_rank_id": 0,
         })
+    if a3_need_extra_args:
+        stage1_kwargs.update({
+            "x_active_mask": mc2_mask,
+        })

     kwargs_mc2.update(stage1_kwargs)

@@ -230,6 +238,10 @@ def fused_experts_with_mc2(
             "tp_world_size": 1,
             "tp_rank_id": 0,
         })
+    if a3_need_extra_args:
+        stage3_kwargs.update({
+            "x_active_mask": mc2_mask,
+        })
     kwargs_mc2.update(stage3_kwargs)

     hidden_states = torch_npu.npu_moe_distribute_combine(**kwargs_mc2)
@@ -944,6 +956,7 @@ def apply(

         fused_moe_state = get_forward_context().fused_moe_state
         if fused_moe_state == FusedMoEState.MC2:
+            mc2_mask = kwargs.get("mc2_mask", None)
             return fused_experts_with_mc2(
                 hidden_states=x,
                 w1=layer.w13_weight,
@@ -954,7 +967,8 @@ def apply(
                 expert_map=expert_map,
                 moe_all_to_all_group_name=self.moe_all_to_all_group_name,
                 shared_experts=shared_experts,
-                is_torchair=self.torchair_graph_enabled)
+                is_torchair=self.torchair_graph_enabled,
+                mc2_mask=mc2_mask)
         elif fused_moe_state == FusedMoEState.AllGather:
             return fused_experts(hidden_states=x,
                                  w1=layer.w13_weight,
@@ -1154,6 +1168,9 @@ def forward(self,
         if shared_experts:
             if not self.enable_multistream_moe or fused_moe_state != FusedMoEState.MC2:
                 shared_hidden_states = shared_experts(hidden_states)
+
+        attn_metadata = get_forward_context().attn_metadata
+        mc2_mask = attn_metadata.decode.mc2_mask if attn_metadata is not None and attn_metadata.decode is not None else None

         tp_size = get_tensor_model_parallel_world_size()
         if tp_size > 1 and fused_moe_state != FusedMoEState.AllGather:
@@ -1171,6 +1188,9 @@ def forward(self,
             tp_rank = get_tensor_model_parallel_rank()
             hidden_states = chunk_hidden_states[tp_rank]
             router_logits = chunk_router_logits[tp_rank]
+            if mc2_mask is not None:
+                chunk_mc2_mask = torch.tensor_split(mc2_mask, tp_size, dim=0)
+                mc2_mask = chunk_mc2_mask[tp_rank]
         if self.dp_size > 1 and fused_moe_state == FusedMoEState.AllGather:
             # NOTE: When in torchair graph, it has been padded in model_runner_v1
             if not self.torchair_graph_enabled or is_prefill:
@@ -1209,6 +1229,7 @@ def forward(self,
             and self.enable_multistream_moe and not is_prefill else None,
             quantized_x_for_share=quantized_x_for_share,
             dynamic_scale_for_share=dynamic_scale_for_share,
+            mc2_mask=mc2_mask,
         )

         if shared_experts:
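
As a rough sketch of the tensor-parallel handling added in forward(): when the hidden states are chunked across TP ranks, the mc2_mask is split the same way so each rank only flags the batch slots it actually processes. The sizes below are hypothetical, and the snippet uses plain torch on CPU rather than the real distributed setup.

    import torch

    tp_size = 4                       # hypothetical TP world size
    mc2_mask = torch.zeros(8, dtype=torch.bool)
    mc2_mask[:5] = True               # 5 active requests, 3 padding slots

    # Mirror of the change in forward(): split the mask the same way the
    # hidden states are chunked, then keep only this rank's slice.
    chunk_mc2_mask = torch.tensor_split(mc2_mask, tp_size, dim=0)
    for tp_rank in range(tp_size):
        print(tp_rank, chunk_mc2_mask[tp_rank])
    # 0 tensor([True, True]) ... 3 tensor([False, False])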

vllm_ascend/quantization/w8a8_dynamic.py

Lines changed: 12 additions & 0 deletions
@@ -215,6 +215,7 @@ def fused_experts_with_mc2(
     w2_scale_bias: torch.Tensor = None,
     quantized_x_for_share: Optional[Any] = None,
     dynamic_scale_for_share: Optional[Any] = None,
+    mc2_mask: Optional[torch.Tensor] = None,
 ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
     if log2phy:
         topk_ids = log2phy[topk_ids]
@@ -232,6 +233,9 @@ def fused_experts_with_mc2(
     # NOTE: Currently, when in A3 or in torchair graph, we need to pass in some extra param into dispatch & combine
     need_extra_args = (get_ascend_soc_version() == AscendSocVersion.A3
                        or is_torchair)
+
+    # NOTE: Currently, when in A3, we need to pass in some extra param into dispatch & combine
+    a3_need_extra_args = get_ascend_soc_version() == AscendSocVersion.A3

     if (expert_map is not None):
         moe_expert_num = len(expert_map) + global_redundant_expert_num
@@ -260,6 +264,10 @@ def fused_experts_with_mc2(
             "tp_world_size": 1,
             "tp_rank_id": 0,
         })
+    if a3_need_extra_args:
+        stage1_kwargs.update({
+            "x_active_mask": mc2_mask,
+        })
     kwargs_mc2.update(stage1_kwargs)

     output = torch_npu.npu_moe_distribute_dispatch(**kwargs_mc2)
@@ -310,6 +318,10 @@ def fused_experts_with_mc2(
             "tp_world_size": 1,
             "tp_rank_id": 0,
         })
+    if a3_need_extra_args:
+        stage3_kwargs.update({
+            "x_active_mask": mc2_mask,
+        })
     kwargs_mc2.update(stage3_kwargs)

     hidden_states = torch_npu.npu_moe_distribute_combine(**kwargs_mc2)
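
A minimal sketch of how the extra dispatch/combine arguments are assembled, with the SoC check reduced to a boolean flag. In the patch, is_a3_soc corresponds to get_ascend_soc_version() == AscendSocVersion.A3, and the resulting kwargs are merged into the arguments of torch_npu.npu_moe_distribute_dispatch / npu_moe_distribute_combine, which are not called here; build_stage_kwargs is a hypothetical helper, not part of the patch.

    from typing import Optional

    import torch

    def build_stage_kwargs(mc2_mask: Optional[torch.Tensor],
                           is_a3_soc: bool, is_torchair: bool) -> dict:
        # The "tp_*" fields are needed on A3 or in torchair graph mode,
        # while "x_active_mask" is an A3-only extra carrying the mc2 mask.
        kwargs = {}
        if is_a3_soc or is_torchair:
            kwargs.update({"tp_world_size": 1, "tp_rank_id": 0})
        if is_a3_soc:
            kwargs.update({"x_active_mask": mc2_mask})
        return kwargs

    mask = torch.tensor([True, True, False, False])
    print(build_stage_kwargs(mask, is_a3_soc=True, is_torchair=False))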

vllm_ascend/worker/model_runner_v1.py

Lines changed: 2 additions & 0 deletions
@@ -1666,6 +1666,8 @@ def _dummy_run(
                     attn_metadata.decode.block_table)
                 torch._dynamo.mark_static(
                     attn_metadata.decode.input_positions)
+                torch._dynamo.mark_static(
+                    attn_metadata.decode.mc2_mask)
                 torch._dynamo.mark_static(attn_metadata.slot_mapping)
                 for kv in self.kv_caches:
                     assert isinstance(
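
A small illustrative sketch of the mark_static call, using a toy CPU tensor instead of the real AscendMLADecodeMetadata field; presumably the intent is that dynamo specializes on the mask rather than treating it as a dynamic input during the dummy-run graph capture.

    import torch

    # Toy stand-in for attn_metadata.decode.mc2_mask (an NPU tensor in the
    # real runner); the shape here is hypothetical.
    mc2_mask = torch.zeros(8, dtype=torch.bool)

    # Mark the tensor's dimensions as static so torch.compile/dynamo does
    # not treat them as dynamic shapes when the graph is captured.
    torch._dynamo.mark_static(mc2_mask)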
