Skip to content

Commit c1979f4

Browse files
committed
add mc2 mask
Signed-off-by: weiguihua2 <weiguihua2@huawei.com>
1 parent ef27370 commit c1979f4

File tree

3 files changed

+7
-7
lines changed

3 files changed

+7
-7
lines changed

vllm_ascend/attention/mla_v1.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -523,8 +523,8 @@ def build(
523523
num_reqs_pad_size]
524524
else:
525525
seq_lens_list = seq_lens.tolist()
526-
mc2_mask = self.generate_active_mask(
527-
num_actual_tokens, num_reqs + num_reqs_pad_size)
526+
mc2_mask = self.generate_active_mask(num_actual_tokens,
527+
num_reqs + num_reqs_pad_size)
528528

529529
decode_metadata = AscendMLADecodeMetadata(
530530
input_positions=input_positions,

vllm_ascend/ops/fused_moe.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ def fused_experts_with_mc2(
138138
# NOTE: Currently, when in A3 or in torchair graph, we need to pass in some extra param into dispatch & combine
139139
need_extra_args = (get_ascend_soc_version() == AscendSocVersion.A3
140140
or is_torchair)
141-
141+
142142
# NOTE: Currently, when in A3, we need to pass in some extra param into dispatch & combine
143143
a3_need_extra_args = get_ascend_soc_version() == AscendSocVersion.A3
144144

@@ -1168,7 +1168,7 @@ def forward(self,
11681168
if shared_experts:
11691169
if not self.enable_multistream_moe or fused_moe_state != FusedMoEState.MC2:
11701170
shared_hidden_states = shared_experts(hidden_states)
1171-
1171+
11721172
attn_metadata = get_forward_context().attn_metadata
11731173
mc2_mask = attn_metadata.decode.mc2_mask if attn_metadata is not None and attn_metadata.decode is not None else None
11741174

@@ -1180,8 +1180,8 @@ def forward(self,
11801180
router_logits = nn.functional.pad(
11811181
router_logits, (0, 0, 0, tp_size - num_tokens))
11821182
if mc2_mask is not None:
1183-
mc2_mask = nn.functional.pad(
1184-
mc2_mask, (0, tp_size - num_tokens))
1183+
mc2_mask = nn.functional.pad(mc2_mask,
1184+
(0, tp_size - num_tokens))
11851185
chunk_hidden_states = torch.tensor_split(hidden_states,
11861186
tp_size,
11871187
dim=0)

vllm_ascend/quantization/w8a8_dynamic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -233,7 +233,7 @@ def fused_experts_with_mc2(
233233
# NOTE: Currently, when in A3 or in torchair graph, we need to pass in some extra param into dispatch & combine
234234
need_extra_args = (get_ascend_soc_version() == AscendSocVersion.A3
235235
or is_torchair)
236-
236+
237237
# NOTE: Currently, when in A3, we need to pass in some extra param into dispatch & combine
238238
a3_need_extra_args = get_ascend_soc_version() == AscendSocVersion.A3
239239

0 commit comments

Comments
 (0)