11
11
from vllm .config import get_current_vllm_config
12
12
from vllm .model_executor .layers .linear import (LinearBase ,
13
13
UnquantizedLinearMethod )
14
+ from vllm .platforms import current_platform
14
15
from vllm .utils import cdiv , round_down
15
16
16
17
from vllm_ascend import envs
@@ -98,6 +99,7 @@ class AscendMLADecodeMetadata:
98
99
attn_mask : Optional [torch .Tensor ] = None
99
100
sin : torch .Tensor = None
100
101
cos : torch .Tensor = None
102
+ mc2_mask : Optional [torch .Tensor ] = None
101
103
102
104
103
105
@dataclass
@@ -213,6 +215,13 @@ def __init__(self,
213
215
self .cos_cache = None
214
216
self .sin_cache = None
215
217
218
def generate_activate_mask(self, actual_seqs_num, batch_size):
    """Build a boolean "active sequence" mask for an MC2 batch.

    Returns a 1-D bool tensor of length ``batch_size`` on the current
    platform's device whose first ``actual_seqs_num`` entries are True
    (real sequences) and whose remaining entries are False (padding).
    """
    # Allocate an all-False mask, then mark the leading real sequences.
    active_mask = torch.zeros(batch_size,
                              dtype=torch.bool,
                              device=current_platform.device_type)
    active_mask[:actual_seqs_num] = True
    return active_mask
224
+
216
225
def reorder_batch (self , input_batch : "InputBatch" ,
217
226
scheduler_output : "SchedulerOutput" ) -> bool :
218
227
# We now want to reorder the batch so that the "decode" requests are at
@@ -355,6 +364,7 @@ def build_torchair_graph_dummy(
355
364
self .rope_dim ,
356
365
dtype = self .runner .dtype ,
357
366
device = device )
367
+ mc2_mask = self .generate_activate_mask (num_actual_tokens , num_reqs )
358
368
decode_metadata = AscendMLADecodeMetadata (
359
369
input_positions = input_positions ,
360
370
block_table = block_table ,
@@ -364,7 +374,8 @@ def build_torchair_graph_dummy(
364
374
attn_mask = self .runner .spec_attn_mask ,
365
375
actual_seq_q_lens = self .runner .actual_seq_q_lens [:num_reqs ],
366
376
sin = sin ,
367
- cos = cos )
377
+ cos = cos ,
378
+ mc2_mask = mc2_mask )
368
379
return self .metadata_cls ( # type: ignore
369
380
num_input_tokens = num_actual_tokens ,
370
381
num_actual_tokens = num_actual_tokens ,
@@ -545,6 +556,8 @@ def build(
545
556
else :
546
557
seq_lens_list = seq_lens .tolist ()
547
558
cos , sin = None , None
559
+ mc2_mask = self .generate_activate_mask (
560
+ num_actual_tokens , num_reqs + num_reqs_pad_size )
548
561
549
562
decode_metadata = AscendMLADecodeMetadata (
550
563
input_positions = input_positions ,
@@ -555,7 +568,8 @@ def build(
555
568
attn_mask = self .runner .spec_attn_mask ,
556
569
actual_seq_q_lens = actual_seq_q_lens ,
557
570
sin = sin ,
558
- cos = cos )
571
+ cos = cos ,
572
+ mc2_mask = mc2_mask )
559
573
560
574
return self .metadata_cls ( # type: ignore
561
575
num_actual_tokens = num_actual_tokens ,
0 commit comments