3 files changed, 10 insertions(+), 3 deletions(-)

@@ -221,7 +221,9 @@ def __init__(self,
         self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
 
     def generate_active_mask(self, actual_seqs_num, batch_size):
-        mc2_mask = torch.zeros(batch_size, dtype=torch.bool, device=current_platform.device_type)
+        mc2_mask = torch.zeros(batch_size,
+                               dtype=torch.bool,
+                               device=current_platform.device_type)
         mc2_mask[:actual_seqs_num].fill_(True)
         return mc2_mask
 
@@ -521,7 +523,8 @@ def build(
                                                   num_reqs_pad_size]
             else:
                 seq_lens_list = seq_lens.tolist()
-            mc2_mask = self.generate_active_mask(num_actual_tokens, num_reqs)
+            mc2_mask = self.generate_active_mask(
+                num_actual_tokens, num_reqs + num_reqs_pad_size)
 
             decode_metadata = AscendMLADecodeMetadata(
                 input_positions=input_positions,
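
For context, here is a minimal standalone sketch of the mask semantics, assuming only what the hunks above show: the first actual_seqs_num entries are active and the rest are padding. The second hunk sizes the mask with num_reqs + num_reqs_pad_size so that, in the padded (presumably torchair) decode path, the mask length matches the padded batch while the padding slots stay False.

# Standalone sketch, not the vllm-ascend code: reproduces the
# generate_active_mask behaviour on CPU for illustration.
import torch

def make_active_mask(actual_seqs_num: int, batch_size: int) -> torch.Tensor:
    # Start with an all-False mask sized to the (possibly padded) batch ...
    mask = torch.zeros(batch_size, dtype=torch.bool)
    # ... and mark only the real (non-padding) requests as active.
    mask[:actual_seqs_num].fill_(True)
    return mask

print(make_active_mask(3, 5))  # tensor([ True,  True,  True, False, False])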

@@ -1188,9 +1188,11 @@ def forward(self,
             tp_rank = get_tensor_model_parallel_rank()
             hidden_states = chunk_hidden_states[tp_rank]
             router_logits = chunk_router_logits[tp_rank]
+
         if mc2_mask is not None:
             chunk_mc2_mask = torch.tensor_split(mc2_mask, tp_size, dim=0)
             mc2_mask = chunk_mc2_mask[tp_rank]
+
         if self.dp_size > 1 and fused_moe_state == FusedMoEState.AllGather:
             # NOTE: When in torchair graph, it has been padded in model_runner_v1
             if not self.torchair_graph_enabled or is_prefill:
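
A rough standalone sketch of the per-rank mask split shown above (tp_size and tp_rank are placeholder values here; in the real forward they come from the tensor-parallel helpers): torch.tensor_split cuts the mask into tp_size chunks along dim 0 and each rank keeps its own chunk, mirroring how chunk_hidden_states and chunk_router_logits are already indexed.

# Sketch only: how mc2_mask is chunked per tensor-parallel rank.
import torch

tp_size, tp_rank = 4, 1   # placeholder world size and rank index
mc2_mask = torch.tensor([True, True, True, False, False, False, False, False])

# Same calls as in the hunk: split along dim 0, keep this rank's chunk.
chunk_mc2_mask = torch.tensor_split(mc2_mask, tp_size, dim=0)
mc2_mask = chunk_mc2_mask[tp_rank]
print(mc2_mask)           # tensor([ True, False])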

@@ -803,6 +803,7 @@ def apply(
         topk_weights = topk_weights.to(x.dtype)
 
         if fused_moe_state == FusedMoEState.MC2:
+            mc2_mask = kwargs.get("mc2_mask", None)
             return fused_experts_with_mc2(
                 hidden_states=x,
                 w1=layer.w13_weight,
@@ -819,7 +820,8 @@ def apply(
                 shared_experts=shared_experts,
                 is_torchair=self.torchair_graph_enabled,
                 quantized_x_for_share=shared_gate_up,
-                dynamic_scale_for_share=shared_dequant_scale)
+                dynamic_scale_for_share=shared_dequant_scale,
+                mc2_mask=mc2_mask)
         elif fused_moe_state == FusedMoEState.AllGather:
             return fused_experts(hidden_states=x,
                                  w1=layer.w13_weight,
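
The quantized MC2 path now pulls the mask out of kwargs and forwards it to fused_experts_with_mc2. Below is a small sketch of that optional-kwarg plumbing under assumed names (downstream is a hypothetical stand-in, not the real fused_experts_with_mc2 signature).

# Sketch of the kwargs.get(...) pass-through pattern used above.
from typing import Optional
import torch

def downstream(hidden_states: torch.Tensor,
               mc2_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
    # Hypothetical consumer: treat a missing mask as "all tokens active".
    if mc2_mask is None:
        return hidden_states
    return hidden_states * mc2_mask.unsqueeze(-1)

def apply(x: torch.Tensor, **kwargs) -> torch.Tensor:
    # Same pattern as the diff: fetch the optional mask (default None)
    # and forward it as an explicit keyword argument.
    mc2_mask = kwargs.get("mc2_mask", None)
    return downstream(hidden_states=x, mc2_mask=mc2_mask)

print(apply(torch.ones(4, 2), mc2_mask=torch.tensor([True, True, False, False])))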