
Commit eef1093

weiguihua2 authored and weijinqian_v1 committed
add mc2 mask (#1642)
### What this PR does / why we need it?
add mc2 mask

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?

---------

Signed-off-by: weiguihua2 <weiguihua2@huawei.com>
1 parent ee1dd49 commit eef1093

File tree

- vllm_ascend/attention/mla_v1.py
- vllm_ascend/ops/fused_moe.py
- vllm_ascend/quantization/w8a8_dynamic.py
- vllm_ascend/worker/model_runner_v1.py

4 files changed (+60, -4 lines)

vllm_ascend/attention/mla_v1.py

Lines changed: 16 additions & 2 deletions
```diff
@@ -11,6 +11,7 @@
 from vllm.config import get_current_vllm_config
 from vllm.model_executor.layers.linear import (LinearBase,
                                                UnquantizedLinearMethod)
+from vllm.platforms import current_platform
 from vllm.utils import cdiv, round_down
 
 from vllm_ascend import envs
@@ -98,6 +99,7 @@ class AscendMLADecodeMetadata:
     attn_mask: Optional[torch.Tensor] = None
     sin: torch.Tensor = None
     cos: torch.Tensor = None
+    mc2_mask: Optional[torch.Tensor] = None
 
 
 @dataclass
@@ -213,6 +215,13 @@ def __init__(self,
         self.cos_cache = None
         self.sin_cache = None
 
+    def generate_activate_mask(self, actual_seqs_num, batch_size):
+        mc2_mask = torch.zeros(batch_size,
+                               dtype=torch.bool,
+                               device=current_platform.device_type)
+        mc2_mask[:actual_seqs_num].fill_(True)
+        return mc2_mask
+
     def reorder_batch(self, input_batch: "InputBatch",
                       scheduler_output: "SchedulerOutput") -> bool:
         # We now want to reorder the batch so that the "decode" requests are at
@@ -355,6 +364,7 @@ def build_torchair_graph_dummy(
                          self.rope_dim,
                          dtype=self.runner.dtype,
                          device=device)
+        mc2_mask = self.generate_activate_mask(num_actual_tokens, num_reqs)
         decode_metadata = AscendMLADecodeMetadata(
             input_positions=input_positions,
             block_table=block_table,
@@ -364,7 +374,8 @@
             attn_mask=self.runner.spec_attn_mask,
             actual_seq_q_lens=self.runner.actual_seq_q_lens[:num_reqs],
             sin=sin,
-            cos=cos)
+            cos=cos,
+            mc2_mask=mc2_mask)
         return self.metadata_cls(  # type: ignore
             num_input_tokens=num_actual_tokens,
             num_actual_tokens=num_actual_tokens,
@@ -545,6 +556,8 @@ def build(
         else:
             seq_lens_list = seq_lens.tolist()
             cos, sin = None, None
+        mc2_mask = self.generate_activate_mask(
+            num_actual_tokens, num_reqs + num_reqs_pad_size)
 
         decode_metadata = AscendMLADecodeMetadata(
             input_positions=input_positions,
@@ -555,7 +568,8 @@
             attn_mask=self.runner.spec_attn_mask,
             actual_seq_q_lens=actual_seq_q_lens,
             sin=sin,
-            cos=cos)
+            cos=cos,
+            mc2_mask=mc2_mask)
 
         return self.metadata_cls(  # type: ignore
             num_actual_tokens=num_actual_tokens,
```
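
The new `mc2_mask` is a boolean vector over the (possibly padded) decode batch: the first `actual_seqs_num` slots mark real requests, the rest are padding. A minimal sketch of the same logic, using a CPU tensor instead of the Ascend device for illustration:

```python
import torch

def generate_activate_mask(actual_seqs_num: int, batch_size: int) -> torch.Tensor:
    # True for real decode requests, False for the padded tail of the batch.
    mc2_mask = torch.zeros(batch_size, dtype=torch.bool)
    mc2_mask[:actual_seqs_num].fill_(True)
    return mc2_mask

# 3 real requests in a batch padded to 8 slots:
print(generate_activate_mask(3, 8))
# tensor([ True,  True,  True, False, False, False, False, False])
```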

vllm_ascend/ops/fused_moe.py

Lines changed: 27 additions & 1 deletion
```diff
@@ -124,6 +124,7 @@ def fused_experts_with_mc2(
     moe_all_to_all_group_name: Optional[str] = None,
     shared_experts: Optional[Any] = None,
     is_torchair: bool = False,
+    mc2_mask: Optional[torch.Tensor] = None,
 ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
     quant_mode = 0
     ep_group = get_ep_group()
@@ -140,6 +141,9 @@ def fused_experts_with_mc2(
     need_extra_args = (get_ascend_soc_version() == AscendSocVersion.A3
                        or is_torchair)
 
+    # NOTE: Currently, when in A3, we need to pass in some extra param into dispatch & combine
+    a3_need_extra_args = get_ascend_soc_version() == AscendSocVersion.A3
+
     moe_expert_num = len(expert_map)
     kwargs_mc2 = {
         "x": hidden_states,
@@ -163,6 +167,10 @@ def fused_experts_with_mc2(
             "tp_world_size": 1,
             "tp_rank_id": 0,
         })
+    if a3_need_extra_args:
+        stage1_kwargs.update({
+            "x_active_mask": mc2_mask,
+        })
 
     kwargs_mc2.update(stage1_kwargs)
 
@@ -232,6 +240,10 @@ def fused_experts_with_mc2(
             "tp_world_size": 1,
             "tp_rank_id": 0,
         })
+    if a3_need_extra_args:
+        stage3_kwargs.update({
+            "x_active_mask": mc2_mask,
+        })
     kwargs_mc2.update(stage3_kwargs)
 
     hidden_states = torch_npu.npu_moe_distribute_combine(**kwargs_mc2)
@@ -958,6 +970,7 @@ def apply(
         fused_moe_state = get_forward_context().fused_moe_state
 
         if fused_moe_state == FusedMoEState.MC2:
+            mc2_mask = kwargs.get("mc2_mask", None)
             return fused_experts_with_mc2(
                 hidden_states=x,
                 w1=layer.w13_weight,
@@ -968,7 +981,8 @@ def apply(
                 expert_map=expert_map,
                 moe_all_to_all_group_name=self.moe_all_to_all_group_name,
                 shared_experts=shared_experts,
-                is_torchair=self.torchair_graph_enabled)
+                is_torchair=self.torchair_graph_enabled,
+                mc2_mask=mc2_mask)
         elif fused_moe_state == FusedMoEState.AllGather:
             return fused_experts(hidden_states=x,
                                  w1=layer.w13_weight,
@@ -1194,13 +1208,19 @@ def forward(self,
         if not self.enable_multistream_moe or fused_moe_state != FusedMoEState.MC2:
             shared_hidden_states = shared_experts(hidden_states)
 
+        attn_metadata = get_forward_context().attn_metadata
+        mc2_mask = attn_metadata.decode.mc2_mask if attn_metadata is not None and attn_metadata.decode is not None else None
+
         tp_size = get_tensor_model_parallel_world_size()
         if tp_size > 1 and fused_moe_state != FusedMoEState.AllGather:
             if num_tokens < tp_size:
                 hidden_states = nn.functional.pad(
                     hidden_states, (0, 0, 0, tp_size - num_tokens))
                 router_logits = nn.functional.pad(
                     router_logits, (0, 0, 0, tp_size - num_tokens))
+                if mc2_mask is not None:
+                    mc2_mask = nn.functional.pad(mc2_mask,
+                                                 (0, tp_size - num_tokens))
             chunk_hidden_states = torch.tensor_split(hidden_states,
                                                      tp_size,
                                                      dim=0)
@@ -1210,6 +1230,11 @@ def forward(self,
             tp_rank = get_tensor_model_parallel_rank()
             hidden_states = chunk_hidden_states[tp_rank]
             router_logits = chunk_router_logits[tp_rank]
+
+            if mc2_mask is not None:
+                chunk_mc2_mask = torch.tensor_split(mc2_mask, tp_size, dim=0)
+                mc2_mask = chunk_mc2_mask[tp_rank]
+
         if self.dp_size > 1 and fused_moe_state == FusedMoEState.AllGather:
             # NOTE: When in torchair graph, it has been padded in model_runner_v1
             if not self.torchair_graph_enabled or is_prefill:
@@ -1248,6 +1273,7 @@ def forward(self,
                 and self.enable_multistream_moe and not is_prefill else None,
                 quantized_x_for_share=quantized_x_for_share,
                 dynamic_scale_for_share=dynamic_scale_for_share,
+                mc2_mask=mc2_mask,
                 token_dispatcher=self.token_dispatcher
             )
```
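
In `forward`, the mask is read from the decode attention metadata and then follows the same pad/split path as `hidden_states`, so each TP rank ends up with the slice of the mask that matches its token chunk. A toy sketch of that flow (the `tp_size`, `tp_rank`, and `num_tokens` values below are illustrative stand-ins for the real runtime state):

```python
import torch
import torch.nn as nn

tp_size, tp_rank = 4, 1          # illustrative TP group
num_tokens = 3
mc2_mask = torch.tensor([True, True, True])  # one entry per decode slot

# Pad the mask together with hidden_states when the batch is smaller
# than the TP group size; padded entries become False.
if num_tokens < tp_size:
    mc2_mask = nn.functional.pad(mc2_mask, (0, tp_size - num_tokens))

# Each rank keeps only its chunk, matching its chunk of hidden_states
# handed to fused_experts_with_mc2.
chunk_mc2_mask = torch.tensor_split(mc2_mask, tp_size, dim=0)
mc2_mask = chunk_mc2_mask[tp_rank]
print(mc2_mask)  # rank 1 sees tensor([True]) in this toy setup
```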

vllm_ascend/quantization/w8a8_dynamic.py

Lines changed: 15 additions & 1 deletion
```diff
@@ -215,6 +215,7 @@ def fused_experts_with_mc2(
     w2_scale_bias: torch.Tensor = None,
     quantized_x_for_share: Optional[Any] = None,
     dynamic_scale_for_share: Optional[Any] = None,
+    mc2_mask: Optional[torch.Tensor] = None,
 ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
     if log2phy:
         topk_ids = log2phy[topk_ids]
@@ -233,6 +234,9 @@ def fused_experts_with_mc2(
     need_extra_args = (get_ascend_soc_version() == AscendSocVersion.A3
                        or is_torchair)
 
+    # NOTE: Currently, when in A3, we need to pass in some extra param into dispatch & combine
+    a3_need_extra_args = get_ascend_soc_version() == AscendSocVersion.A3
+
     if (expert_map is not None):
         moe_expert_num = len(expert_map) + global_redundant_expert_num
     else:
@@ -260,6 +264,10 @@ def fused_experts_with_mc2(
             "tp_world_size": 1,
             "tp_rank_id": 0,
         })
+    if a3_need_extra_args:
+        stage1_kwargs.update({
+            "x_active_mask": mc2_mask,
+        })
     kwargs_mc2.update(stage1_kwargs)
 
     output = torch_npu.npu_moe_distribute_dispatch(**kwargs_mc2)
@@ -310,6 +318,10 @@ def fused_experts_with_mc2(
             "tp_world_size": 1,
             "tp_rank_id": 0,
         })
+    if a3_need_extra_args:
+        stage3_kwargs.update({
+            "x_active_mask": mc2_mask,
+        })
     kwargs_mc2.update(stage3_kwargs)
 
     hidden_states = torch_npu.npu_moe_distribute_combine(**kwargs_mc2)
@@ -791,6 +803,7 @@ def apply(
         topk_weights = topk_weights.to(x.dtype)
 
         if fused_moe_state == FusedMoEState.MC2:
+            mc2_mask = kwargs.get("mc2_mask", None)
             return fused_experts_with_mc2(
                 hidden_states=x,
                 w1=layer.w13_weight,
@@ -807,7 +820,8 @@ def apply(
                 shared_experts=shared_experts,
                 is_torchair=self.torchair_graph_enabled,
                 quantized_x_for_share=shared_gate_up,
-                dynamic_scale_for_share=shared_dequant_scale)
+                dynamic_scale_for_share=shared_dequant_scale,
+                mc2_mask=mc2_mask)
         elif fused_moe_state == FusedMoEState.AllGather:
             return fused_experts(hidden_states=x,
                                  w1=layer.w13_weight,
```
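
Both the unquantized and the w8a8-dynamic paths gate the new argument the same way: `x_active_mask` is only added to the dispatch (stage 1) and combine (stage 3) kwargs when the SoC is A3. A hedged sketch of that pattern, with `is_a3_soc` and `build_stage_kwargs` as illustrative stand-ins for `get_ascend_soc_version() == AscendSocVersion.A3` and the in-place `stage*_kwargs.update(...)` calls:

```python
from typing import Optional
import torch

def build_stage_kwargs(mc2_mask: Optional[torch.Tensor],
                       is_a3_soc: bool) -> dict:
    # Base kwargs shared by dispatch/combine in this sketch.
    stage_kwargs = {"tp_world_size": 1, "tp_rank_id": 0}
    # Only the A3 kernels accept the active-token mask, so add it conditionally.
    if is_a3_soc:
        stage_kwargs["x_active_mask"] = mc2_mask
    return stage_kwargs

mask = torch.tensor([True, True, False, False])
print(build_stage_kwargs(mask, is_a3_soc=True))
print(build_stage_kwargs(mask, is_a3_soc=False))
```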

vllm_ascend/worker/model_runner_v1.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -1668,6 +1668,8 @@ def _dummy_run(
                     attn_metadata.decode.input_positions)
                 torch._dynamo.mark_static(attn_metadata.decode.sin)
                 torch._dynamo.mark_static(attn_metadata.decode.cos)
+                torch._dynamo.mark_static(
+                    attn_metadata.decode.mc2_mask)
                 torch._dynamo.mark_static(attn_metadata.slot_mapping)
             for kv in self.kv_caches:
                 assert isinstance(
```
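
Because the mask is captured by the torchair graph alongside `sin`/`cos`, the dummy run marks it static as well. A minimal illustration of the call (the standalone tensor below is only a stand-in for `attn_metadata.decode.mc2_mask`):

```python
import torch

# Stand-in for attn_metadata.decode.mc2_mask in the dummy run.
mc2_mask = torch.zeros(8, dtype=torch.bool)

# Mark the tensor static so dynamo/torchair treats its sizes as fixed
# rather than re-specializing on them across graph replays.
torch._dynamo.mark_static(mc2_mask)
```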
