
Commit 31208b4

add mc2 mask (#1642)
### What this PR does / why we need it?
Add mc2 mask.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?

---------

Signed-off-by: weiguihua2 <weiguihua2@huawei.com>
1 parent e99d232 commit 31208b4

4 files changed: +60 -4 lines changed


vllm_ascend/attention/mla_v1.py

Lines changed: 16 additions & 2 deletions
@@ -11,6 +11,7 @@
 from vllm.config import get_current_vllm_config
 from vllm.model_executor.layers.linear import (LinearBase,
                                                UnquantizedLinearMethod)
+from vllm.platforms import current_platform
 from vllm.utils import cdiv, round_down

 from vllm_ascend import envs
@@ -98,6 +99,7 @@ class AscendMLADecodeMetadata:
     attn_mask: Optional[torch.Tensor] = None
     sin: torch.Tensor = None
     cos: torch.Tensor = None
+    mc2_mask: Optional[torch.Tensor] = None


 @dataclass
@@ -213,6 +215,13 @@ def __init__(self,
         self.cos_cache = None
         self.sin_cache = None

+    def generate_activate_mask(self, actual_seqs_num, batch_size):
+        mc2_mask = torch.zeros(batch_size,
+                               dtype=torch.bool,
+                               device=current_platform.device_type)
+        mc2_mask[:actual_seqs_num].fill_(True)
+        return mc2_mask
+
     def reorder_batch(self, input_batch: "InputBatch",
                       scheduler_output: "SchedulerOutput") -> bool:
         # We now want to reorder the batch so that the "decode" requests are at
@@ -355,6 +364,7 @@ def build_torchair_graph_dummy(
             self.rope_dim,
             dtype=self.runner.dtype,
             device=device)
+        mc2_mask = self.generate_activate_mask(num_actual_tokens, num_reqs)
         decode_metadata = AscendMLADecodeMetadata(
             input_positions=input_positions,
             block_table=block_table,
@@ -364,7 +374,8 @@ def build_torchair_graph_dummy(
             attn_mask=self.runner.spec_attn_mask,
             actual_seq_q_lens=self.runner.actual_seq_q_lens[:num_reqs],
             sin=sin,
-            cos=cos)
+            cos=cos,
+            mc2_mask=mc2_mask)
         return self.metadata_cls(  # type: ignore
             num_input_tokens=num_actual_tokens,
             num_actual_tokens=num_actual_tokens,
@@ -545,6 +556,8 @@ def build(
         else:
             seq_lens_list = seq_lens.tolist()
             cos, sin = None, None
+        mc2_mask = self.generate_activate_mask(
+            num_actual_tokens, num_reqs + num_reqs_pad_size)

         decode_metadata = AscendMLADecodeMetadata(
             input_positions=input_positions,
@@ -555,7 +568,8 @@ def build(
             attn_mask=self.runner.spec_attn_mask,
             actual_seq_q_lens=actual_seq_q_lens,
             sin=sin,
-            cos=cos)
+            cos=cos,
+            mc2_mask=mc2_mask)

         return self.metadata_cls(  # type: ignore
             num_actual_tokens=num_actual_tokens,
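
For reference, the new `generate_activate_mask` helper simply marks which slots in a (possibly padded) decode batch correspond to real requests. A minimal standalone sketch of the same logic, using a CPU tensor for illustration instead of the NPU device selected via `current_platform.device_type` in the diff:

import torch

def generate_activate_mask(actual_seqs_num: int, batch_size: int) -> torch.Tensor:
    # True for the first `actual_seqs_num` slots (real requests),
    # False for the padded tail of the batch.
    mc2_mask = torch.zeros(batch_size, dtype=torch.bool)
    mc2_mask[:actual_seqs_num].fill_(True)
    return mc2_mask

# e.g. 3 real decode requests padded up to a batch of 8:
print(generate_activate_mask(3, 8))
# tensor([ True,  True,  True, False, False, False, False, False])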

vllm_ascend/ops/fused_moe.py

Lines changed: 27 additions & 1 deletion
@@ -122,6 +122,7 @@ def fused_experts_with_mc2(
     moe_all_to_all_group_name: Optional[str] = None,
     shared_experts: Optional[Any] = None,
     is_torchair: bool = False,
+    mc2_mask: Optional[torch.Tensor] = None,
 ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
     quant_mode = 0
     ep_group = get_ep_group()
@@ -138,6 +139,9 @@ def fused_experts_with_mc2(
     need_extra_args = (get_ascend_soc_version() == AscendSocVersion.A3
                        or is_torchair)

+    # NOTE: Currently, when in A3, we need to pass in some extra param into dispatch & combine
+    a3_need_extra_args = get_ascend_soc_version() == AscendSocVersion.A3
+
     moe_expert_num = len(expert_map)
     kwargs_mc2 = {
         "x": hidden_states,
@@ -161,6 +165,10 @@ def fused_experts_with_mc2(
            "tp_world_size": 1,
            "tp_rank_id": 0,
        })
+        if a3_need_extra_args:
+            stage1_kwargs.update({
+                "x_active_mask": mc2_mask,
+            })

     kwargs_mc2.update(stage1_kwargs)

@@ -230,6 +238,10 @@ def fused_experts_with_mc2(
            "tp_world_size": 1,
            "tp_rank_id": 0,
        })
+        if a3_need_extra_args:
+            stage3_kwargs.update({
+                "x_active_mask": mc2_mask,
+            })
     kwargs_mc2.update(stage3_kwargs)

     hidden_states = torch_npu.npu_moe_distribute_combine(**kwargs_mc2)
@@ -944,6 +956,7 @@ def apply(

        fused_moe_state = get_forward_context().fused_moe_state
        if fused_moe_state == FusedMoEState.MC2:
+            mc2_mask = kwargs.get("mc2_mask", None)
            return fused_experts_with_mc2(
                hidden_states=x,
                w1=layer.w13_weight,
@@ -954,7 +967,8 @@ def apply(
                expert_map=expert_map,
                moe_all_to_all_group_name=self.moe_all_to_all_group_name,
                shared_experts=shared_experts,
-                is_torchair=self.torchair_graph_enabled)
+                is_torchair=self.torchair_graph_enabled,
+                mc2_mask=mc2_mask)
        elif fused_moe_state == FusedMoEState.AllGather:
            return fused_experts(hidden_states=x,
                                 w1=layer.w13_weight,
@@ -1155,13 +1169,19 @@ def forward(self,
        if not self.enable_multistream_moe or fused_moe_state != FusedMoEState.MC2:
            shared_hidden_states = shared_experts(hidden_states)

+        attn_metadata = get_forward_context().attn_metadata
+        mc2_mask = attn_metadata.decode.mc2_mask if attn_metadata is not None and attn_metadata.decode is not None else None
+
        tp_size = get_tensor_model_parallel_world_size()
        if tp_size > 1 and fused_moe_state != FusedMoEState.AllGather:
            if num_tokens < tp_size:
                hidden_states = nn.functional.pad(
                    hidden_states, (0, 0, 0, tp_size - num_tokens))
                router_logits = nn.functional.pad(
                    router_logits, (0, 0, 0, tp_size - num_tokens))
+                if mc2_mask is not None:
+                    mc2_mask = nn.functional.pad(mc2_mask,
+                                                 (0, tp_size - num_tokens))
            chunk_hidden_states = torch.tensor_split(hidden_states,
                                                     tp_size,
                                                     dim=0)
@@ -1171,6 +1191,11 @@ def forward(self,
            tp_rank = get_tensor_model_parallel_rank()
            hidden_states = chunk_hidden_states[tp_rank]
            router_logits = chunk_router_logits[tp_rank]
+
+            if mc2_mask is not None:
+                chunk_mc2_mask = torch.tensor_split(mc2_mask, tp_size, dim=0)
+                mc2_mask = chunk_mc2_mask[tp_rank]
+
        if self.dp_size > 1 and fused_moe_state == FusedMoEState.AllGather:
            # NOTE: When in torchair graph, it has been padded in model_runner_v1
            if not self.torchair_graph_enabled or is_prefill:
@@ -1209,6 +1234,7 @@ def forward(self,
            and self.enable_multistream_moe and not is_prefill else None,
            quantized_x_for_share=quantized_x_for_share,
            dynamic_scale_for_share=dynamic_scale_for_share,
+            mc2_mask=mc2_mask,
        )

        if shared_experts:
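
The only non-mechanical part of the fused_moe.py change is the tensor-parallel path: when the decode batch is padded up to `tp_size` and then split across ranks, `mc2_mask` has to be padded and split the same way so each rank's mask slice stays aligned with its token slice. A small sketch of that bookkeeping, with hypothetical `tp_size`, `tp_rank`, and `num_tokens` values chosen only for illustration:

import torch
import torch.nn.functional as F

tp_size, tp_rank = 4, 1        # hypothetical TP layout
num_tokens = 3                 # fewer tokens than ranks, so padding is needed
mc2_mask = torch.ones(num_tokens, dtype=torch.bool)

if num_tokens < tp_size:
    # pad the 1-D mask on the right with False, mirroring the padding of hidden_states
    mc2_mask = F.pad(mc2_mask, (0, tp_size - num_tokens))

# each rank keeps only its own chunk of the mask
chunk_mc2_mask = torch.tensor_split(mc2_mask, tp_size, dim=0)
mc2_mask = chunk_mc2_mask[tp_rank]
print(mc2_mask)                # tensor([True]) on rank 1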

vllm_ascend/quantization/w8a8_dynamic.py

Lines changed: 15 additions & 1 deletion
@@ -215,6 +215,7 @@ def fused_experts_with_mc2(
     w2_scale_bias: torch.Tensor = None,
     quantized_x_for_share: Optional[Any] = None,
     dynamic_scale_for_share: Optional[Any] = None,
+    mc2_mask: Optional[torch.Tensor] = None,
 ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
     if log2phy:
         topk_ids = log2phy[topk_ids]
@@ -233,6 +234,9 @@ def fused_experts_with_mc2(
     need_extra_args = (get_ascend_soc_version() == AscendSocVersion.A3
                        or is_torchair)

+    # NOTE: Currently, when in A3, we need to pass in some extra param into dispatch & combine
+    a3_need_extra_args = get_ascend_soc_version() == AscendSocVersion.A3
+
     if (expert_map is not None):
         moe_expert_num = len(expert_map) + global_redundant_expert_num
     else:
@@ -260,6 +264,10 @@ def fused_experts_with_mc2(
            "tp_world_size": 1,
            "tp_rank_id": 0,
        })
+        if a3_need_extra_args:
+            stage1_kwargs.update({
+                "x_active_mask": mc2_mask,
+            })
     kwargs_mc2.update(stage1_kwargs)

     output = torch_npu.npu_moe_distribute_dispatch(**kwargs_mc2)
@@ -310,6 +318,10 @@ def fused_experts_with_mc2(
            "tp_world_size": 1,
            "tp_rank_id": 0,
        })
+        if a3_need_extra_args:
+            stage3_kwargs.update({
+                "x_active_mask": mc2_mask,
+            })
     kwargs_mc2.update(stage3_kwargs)

     hidden_states = torch_npu.npu_moe_distribute_combine(**kwargs_mc2)
@@ -791,6 +803,7 @@ def apply(
        topk_weights = topk_weights.to(x.dtype)

        if fused_moe_state == FusedMoEState.MC2:
+            mc2_mask = kwargs.get("mc2_mask", None)
            return fused_experts_with_mc2(
                hidden_states=x,
                w1=layer.w13_weight,
@@ -807,7 +820,8 @@ def apply(
                shared_experts=shared_experts,
                is_torchair=self.torchair_graph_enabled,
                quantized_x_for_share=shared_gate_up,
-                dynamic_scale_for_share=shared_dequant_scale)
+                dynamic_scale_for_share=shared_dequant_scale,
+                mc2_mask=mc2_mask)
        elif fused_moe_state == FusedMoEState.AllGather:
            return fused_experts(hidden_states=x,
                                 w1=layer.w13_weight,
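
As in fused_moe.py, the quantized path only attaches the mask when the SoC is A3, since only the A3 dispatch/combine kernels accept `x_active_mask`. A condensed sketch of that gating pattern, with a hypothetical `build_stage_kwargs` helper and `is_a3_soc` flag standing in for the `get_ascend_soc_version()` check:

from typing import Optional
import torch

def build_stage_kwargs(base_kwargs: dict,
                       is_a3_soc: bool,
                       mc2_mask: Optional[torch.Tensor]) -> dict:
    # The extra x_active_mask entry is only added on A3 SoCs; other SoCs
    # never see the key, so the kernel calls stay unchanged there.
    kwargs = dict(base_kwargs)
    if is_a3_soc:
        kwargs["x_active_mask"] = mc2_mask
    return kwargs

# usage sketch: the mask only appears in the kwargs when is_a3_soc is True
print("x_active_mask" in build_stage_kwargs({"tp_world_size": 1}, False, None))   # False
print("x_active_mask" in build_stage_kwargs({"tp_world_size": 1}, True,
                                            torch.ones(4, dtype=torch.bool)))     # True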

vllm_ascend/worker/model_runner_v1.py

Lines changed: 2 additions & 0 deletions
@@ -1668,6 +1668,8 @@ def _dummy_run(
                        attn_metadata.decode.input_positions)
                    torch._dynamo.mark_static(attn_metadata.decode.sin)
                    torch._dynamo.mark_static(attn_metadata.decode.cos)
+                    torch._dynamo.mark_static(
+                        attn_metadata.decode.mc2_mask)
                    torch._dynamo.mark_static(attn_metadata.slot_mapping)
                    for kv in self.kv_caches:
                        assert isinstance(
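
The runner change simply extends the existing `mark_static` calls so the new mask tensor is treated like `sin`/`cos` when the torchair graph is traced. A minimal sketch of the same call on a stand-in tensor (assuming the private `torch._dynamo.mark_static` API, which the surrounding code already relies on):

import torch
import torch._dynamo

mc2_mask = torch.zeros(8, dtype=torch.bool)   # stand-in for attn_metadata.decode.mc2_mask
# Tell Dynamo to treat this tensor's shape as static rather than dynamic,
# matching how the other decode metadata tensors are handled.
torch._dynamo.mark_static(mc2_mask)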
