@@ -180,12 +180,6 @@ def __init__(
         else:
             self.gate.e_score_correction_bias = None
 
-        self.enable_cv_parallel = False
-        additional_config = get_current_vllm_config().additional_config
-        if additional_config:
-            self.enable_cv_parallel = additional_config.get(
-                "enable_cv_parallel", False)
-
         self.experts = AscendFusedMoE(
             num_experts=config.n_routed_experts,
             top_k=config.num_experts_per_tok,
@@ -222,10 +216,13 @@ def __init__(
         self.params_dtype = torch.get_default_dtype()
 
         self.enable_graph_mode = False
+        self.enable_multistream_shared_expert = False
         additional_config = get_current_vllm_config().additional_config
         if additional_config:
             self.enable_graph_mode = additional_config.get(
                 "enable_graph_mode", False)
+            self.enable_multistream_shared_expert = additional_config.get(
+                "enable_multistream_shared_expert", False)
 
     def forward(
             self,
@@ -248,10 +245,10 @@ def forward(
 
         num_tokens, hidden_size = hidden_states.shape
 
-        cv_parallel = self.enable_cv_parallel and not is_prefill
+        multistream = self.enable_multistream_shared_expert and not is_prefill
 
         if self.n_shared_experts is not None:
-            if not cv_parallel:
+            if not multistream:
                 shared_output = self.shared_experts(hidden_states)
             else:
                 shared_hidden_states = hidden_states
@@ -275,41 +272,25 @@ def forward(
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
 
-        if self.n_shared_experts is not None and cv_parallel:
-            with tng.scope.npu_stream_switch('cv'):
-                tng.scope.npu_wait_tensor(shared_hidden_states, router_logits)
-                dynamic_scale = None
-                if self.shared_experts.is_dynamic_quant:
-                    x, dynamic_scale = torch_npu.npu_dynamic_quant(
-                        shared_hidden_states)
-                    gate_up = torch_npu.npu_quant_matmul(
-                        x,
-                        self.shared_experts.gate_up_proj.weight,
-                        self.shared_experts.gate_up_proj.weight_scale,
-                        output_dtype=torch.int32,
-                    )
-                else:
-                    gate_up, _ = self.gate_up_proj(shared_hidden_states)
-
-        if cv_parallel:
-            hidden_states, shared_output = self.experts(
-                hidden_states=hidden_states,
-                router_logits=router_logits,
-                is_prefill=is_prefill,
-                top_k=CustomDeepseekV2MoE.top_k,
-                enable_force_load_balance=enable_force_load_balance,
-                shared_experts=self.shared_experts,
-                shared_gate_up=gate_up,
-                shared_dynamic_scale=dynamic_scale)
-            hidden_states = hidden_states * self.routed_scaling_factor
-        else:
-            hidden_states = self.experts(
-                hidden_states=hidden_states,
-                router_logits=router_logits,
-                is_prefill=is_prefill,
-                top_k=CustomDeepseekV2MoE.top_k,
-                enable_force_load_balance=enable_force_load_balance,
-            ) * self.routed_scaling_factor
+        kwargs = {}
+        if multistream:
+            kwargs.update({
+                "shared_experts": self.shared_experts,
+                "shared_hidden_states": shared_hidden_states
+            })
+
+        hidden_states = self.experts(
+            hidden_states=hidden_states,
+            router_logits=router_logits,
+            is_prefill=is_prefill,
+            top_k=CustomDeepseekV2MoE.top_k,
+            enable_force_load_balance=enable_force_load_balance,
+            **kwargs)
+
+        if multistream:
+            hidden_states, shared_output = hidden_states
+
+        hidden_states = hidden_states * self.routed_scaling_factor
 
         if self.tp_size > 1:
             if self.enable_graph_mode:
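
For reference, a minimal sketch of the call pattern this diff introduces: the optional shared_experts / shared_hidden_states kwargs are forwarded into the fused-MoE call, which is then assumed to return a (routed_output, shared_output) tuple so the shared-expert work can be overlapped with routed-expert compute. The stub class and placeholder logic below are illustrative assumptions, not the actual AscendFusedMoE implementation.

```python
import torch


class FusedMoESketch(torch.nn.Module):
    """Illustrative stand-in for the fused-MoE forward signature (assumed)."""

    def forward(self, hidden_states, router_logits, **kwargs):
        routed_output = hidden_states  # placeholder for routed-expert compute
        shared_experts = kwargs.get("shared_experts")
        if shared_experts is not None:
            # With multistream enabled, the shared experts are computed inside
            # the fused call (e.g. overlapped on a second stream) and returned
            # alongside the routed result as a tuple, matching the unpacking
            # `hidden_states, shared_output = hidden_states` in the diff above.
            shared_output = shared_experts(kwargs["shared_hidden_states"])
            return routed_output, shared_output
        return routed_output
```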