@@ -208,7 +208,7 @@ def __init__(self,
208
208
ascend_config = get_ascend_config ()
209
209
self .torchair_graph_enabled = ascend_config .torchair_graph_config .enabled
210
210
211
- def generate_active_mask (self , actual_seqs_num , batch_size ):
211
+ def generate_activate_mask (self , actual_seqs_num , batch_size ):
212
212
mc2_mask = torch .zeros (batch_size ,
213
213
dtype = torch .bool ,
214
214
device = current_platform .device_type )
@@ -345,7 +345,7 @@ def build_torchair_graph_dummy(
345
345
else :
346
346
attn_state = AscendAttentionState .DecodeOnly
347
347
num_decode_tokens = 1
348
- mc2_mask = self .generate_active_mask (num_actual_tokens , num_reqs )
348
+ mc2_mask = self .generate_activate_mask (num_actual_tokens , num_reqs )
349
349
decode_metadata = AscendMLADecodeMetadata (
350
350
input_positions = input_positions ,
351
351
block_table = block_table ,
@@ -511,7 +511,7 @@ def build(
511
511
num_reqs_pad_size ]
512
512
else :
513
513
seq_lens_list = seq_lens .tolist ()
514
- mc2_mask = self .generate_active_mask (num_actual_tokens ,
514
+ mc2_mask = self .generate_activate_mask (num_actual_tokens ,
515
515
num_reqs + num_reqs_pad_size )
516
516
517
517
decode_metadata = AscendMLADecodeMetadata (
0 commit comments