Skip to content

Commit 8ca93a8

Browse files
committed
add mc2 mask
Signed-off-by: weiguihua2 <weiguihua2@huawei.com>
1 parent 7d8a6f4 commit 8ca93a8

File tree

2 files changed

+18
-2
lines changed

2 files changed

+18
-2
lines changed

vllm_ascend/attention/mla_v1.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from vllm.config import get_current_vllm_config
1212
from vllm.model_executor.layers.linear import (LinearBase,
1313
UnquantizedLinearMethod)
14+
from vllm.platforms import current_platform
1415
from vllm.utils import cdiv, round_down
1516

1617
from vllm_ascend import envs
@@ -98,6 +99,7 @@ class AscendMLADecodeMetadata:
9899
attn_mask: Optional[torch.Tensor] = None
99100
sin: torch.Tensor = None
100101
cos: torch.Tensor = None
102+
mc2_mask: Optional[torch.Tensor] = None
101103

102104

103105
@dataclass
@@ -213,6 +215,13 @@ def __init__(self,
213215
self.cos_cache = None
214216
self.sin_cache = None
215217

218+
def generate_activate_mask(self, actual_seqs_num, batch_size):
    """Build the boolean MC2 activation mask for a (possibly padded) batch.

    The first *actual_seqs_num* entries are set True (real, active
    sequences); any padded tail beyond that stays False.

    Args:
        actual_seqs_num: number of genuinely active sequences in the batch.
        batch_size: total mask length, including padding slots.

    Returns:
        A 1-D bool tensor of shape ``(batch_size,)`` on the current
        platform's device type.
    """
    active = torch.zeros(batch_size,
                         dtype=torch.bool,
                         device=current_platform.device_type)
    # Mark only the leading real sequences; padding remains inactive.
    active[:actual_seqs_num] = True
    return active
224+
216225
def reorder_batch(self, input_batch: "InputBatch",
217226
scheduler_output: "SchedulerOutput") -> bool:
218227
# We now want to reorder the batch so that the "decode" requests are at
@@ -355,6 +364,7 @@ def build_torchair_graph_dummy(
355364
self.rope_dim,
356365
dtype=self.runner.dtype,
357366
device=device)
367+
mc2_mask = self.generate_activate_mask(num_actual_tokens, num_reqs)
358368
decode_metadata = AscendMLADecodeMetadata(
359369
input_positions=input_positions,
360370
block_table=block_table,
@@ -364,7 +374,8 @@ def build_torchair_graph_dummy(
364374
attn_mask=self.runner.spec_attn_mask,
365375
actual_seq_q_lens=self.runner.actual_seq_q_lens[:num_reqs],
366376
sin=sin,
367-
cos=cos)
377+
cos=cos,
378+
mc2_mask=mc2_mask)
368379
return self.metadata_cls( # type: ignore
369380
num_input_tokens=num_actual_tokens,
370381
num_actual_tokens=num_actual_tokens,
@@ -545,6 +556,8 @@ def build(
545556
else:
546557
seq_lens_list = seq_lens.tolist()
547558
cos, sin = None, None
559+
mc2_mask = self.generate_activate_mask(
560+
num_actual_tokens, num_reqs + num_reqs_pad_size)
548561

549562
decode_metadata = AscendMLADecodeMetadata(
550563
input_positions=input_positions,
@@ -555,7 +568,8 @@ def build(
555568
attn_mask=self.runner.spec_attn_mask,
556569
actual_seq_q_lens=actual_seq_q_lens,
557570
sin=sin,
558-
cos=cos)
571+
cos=cos,
572+
mc2_mask=mc2_mask)
559573

560574
return self.metadata_cls( # type: ignore
561575
num_actual_tokens=num_actual_tokens,

vllm_ascend/worker/model_runner_v1.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1668,6 +1668,8 @@ def _dummy_run(
16681668
attn_metadata.decode.input_positions)
16691669
torch._dynamo.mark_static(attn_metadata.decode.sin)
16701670
torch._dynamo.mark_static(attn_metadata.decode.cos)
1671+
torch._dynamo.mark_static(
1672+
attn_metadata.decode.mc2_mask)
16711673
torch._dynamo.mark_static(attn_metadata.slot_mapping)
16721674
for kv in self.kv_caches:
16731675
assert isinstance(

0 commit comments

Comments (0)