1 file changed, 5 insertions(+), 5 deletions(-)
@@ -120,7 +120,10 @@ def fused_experts_with_mc2(
     moe_all_to_all_group_name: Optional[str] = None,
     shared_experts: Optional[Any] = None
 ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
-    global_bs = 0
+    vllm_config = get_current_vllm_config()
+    ep_group = get_ep_group().device_group
+    all_to_all_group_size = torch.distributed.get_world_size(ep_group)
+    global_bs = vllm_config.scheduler_config.max_num_seqs * all_to_all_group_size
     moe_expert_num = len(expert_map)
     kwargs_mc2 = {
         "x": hidden_states,
@@ -132,11 +135,8 @@ def fused_experts_with_mc2(
     }

     rank = torch.distributed.get_rank()
-
     quant_mode = 0
-    ep_group = get_ep_group().device_group
     local_rank = torch.distributed.get_rank(group=ep_group)
-    all_to_all_group_size = torch.distributed.get_world_size(ep_group)

     tp_size = get_etp_group().world_size
     tp_rank = rank % tp_size
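The deleted lines here are not dropped logic: `ep_group` and `all_to_all_group_size` simply moved to the top of the function (first hunk) so `global_bs` can be computed before `kwargs_mc2` is built. What remains is rank bookkeeping that maps each global rank onto its tensor-parallel slot. A self-contained sketch with made-up sizes (in the diff, `tp_size` comes from `get_etp_group().world_size` and `rank` from `torch.distributed.get_rank()`):

```python
# Example sizes for illustration only.
world_size = 8   # total ranks
tp_size = 2      # ranks per tensor-parallel group

for rank in range(world_size):
    tp_rank = rank % tp_size  # position of this rank within its TP group
    print(f"rank={rank} -> tp_rank={tp_rank}")
```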
@@ -204,7 +204,7 @@ def fused_experts_with_mc2(
         "expert_shard_type": 0,
         "shared_expert_rank_num": 0,
         "moe_expert_num": moe_expert_num,
-        "global_bs": 0,
+        "global_bs": global_bs,
     }
     tp_recv_counts = output[5]
     stage3_kwargs = {
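With the hard-coded 0 replaced, the `global_bs` computed at the top of the function now reaches the kernel kwargs, so the combine stage is sized consistently with dispatch. A hypothetical snippet of the resulting dict (values are examples; in the function they come from the first hunk and from `len(expert_map)`):

```python
# Example values only.
global_bs = 4096      # computed as max_num_seqs * all_to_all_group_size
moe_expert_num = 64   # len(expert_map) in the real function

stage_kwargs = {
    "expert_shard_type": 0,
    "shared_expert_rank_num": 0,
    "moe_expert_num": moe_expert_num,
    "global_bs": global_bs,  # previously hard-coded to 0
}
assert stage_kwargs["global_bs"] == global_bs
```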