@@ -118,9 +118,13 @@ def fused_experts_with_mc2(
     top_k: int,
     expert_map: torch.Tensor = None,
     moe_all_to_all_group_name: Optional[str] = None,
-    shared_experts: Optional[Any] = None
+    shared_experts: Optional[Any] = None,
+    global_batch_size: int = 256,
 ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
-    global_bs = 0
+
+    ep_group = get_ep_group().device_group
+    all_to_all_group_size = torch.distributed.get_world_size(ep_group)
+    global_bs = global_batch_size * all_to_all_group_size
     moe_expert_num = len(expert_map)
     kwargs_mc2 = {
         "x": hidden_states,
@@ -132,11 +136,8 @@ def fused_experts_with_mc2(
     }

     rank = torch.distributed.get_rank()
-
     quant_mode = 0
-    ep_group = get_ep_group().device_group
     local_rank = torch.distributed.get_rank(group=ep_group)
-    all_to_all_group_size = torch.distributed.get_world_size(ep_group)

     tp_size = get_etp_group().world_size
     tp_rank = rank % tp_size
@@ -204,7 +205,7 @@ def fused_experts_with_mc2(
         "expert_shard_type": 0,
         "shared_expert_rank_num": 0,
         "moe_expert_num": moe_expert_num,
-        "global_bs": 0,
+        "global_bs": global_bs,
     }
     tp_recv_counts = output[5]
     stage3_kwargs = {
@@ -1037,7 +1038,8 @@ def apply(
                 top_k=top_k,
                 expert_map=expert_map,
                 moe_all_to_all_group_name=self.moe_all_to_all_group_name,
-                shared_experts=shared_experts)
+                shared_experts=shared_experts,
+                global_batch_size=self.global_batch_size)
         elif fused_moe_state == FusedMoEState.AllGather:
             return fused_experts(hidden_states=x,
                                  w1=layer.w13_weight,
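At the call site, `apply` now threads `self.global_batch_size` into the MC2 path alongside the existing `moe_all_to_all_group_name`. A hedged sketch of that pattern, with a hypothetical class name and constructor (only the attribute name and the new keyword argument are confirmed by the diff above):

class MoEMethodSketch:
    # Illustrative stand-in for the MoE method object that owns apply().
    def __init__(self, moe_all_to_all_group_name: str, global_batch_size: int = 256):
        # Stored once at construction so every apply() call forwards the same
        # global batch size down to fused_experts_with_mc2.
        self.moe_all_to_all_group_name = moe_all_to_all_group_name
        self.global_batch_size = global_batch_size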