Commit 3511331

refactor in deepseek moe
Signed-off-by: David9857 <985700846@qq.com>
1 parent 4fa61d5 commit 3511331

File tree: 3 files changed, +47 -48 lines

  vllm_ascend/models/deepseek_v2.py
  vllm_ascend/ops/fused_moe.py
  vllm_ascend/quantization/w8a8_dynamic.py
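For context, a minimal sketch of how the renamed flag would be switched on from user code. This is not part of the commit; it assumes that additional_config passed to the LLM entry point ends up in the config returned by get_current_vllm_config(), and the model name is only a placeholder.

# Hypothetical enablement sketch (not part of this commit). Assumes the LLM
# entry point forwards additional_config into the vllm config object that
# get_current_vllm_config() returns, as vllm-ascend expects.
from vllm import LLM

llm = LLM(
    model="deepseek-ai/DeepSeek-V2-Lite",  # placeholder model name
    additional_config={
        "enable_graph_mode": True,
        # renamed in this commit from "enable_cv_parallel"
        "enable_multistream_shared_expert": True,
    },
)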

vllm_ascend/models/deepseek_v2.py

Lines changed: 24 additions & 43 deletions
@@ -180,12 +180,6 @@ def __init__(
         else:
             self.gate.e_score_correction_bias = None
 
-        self.enable_cv_parallel = False
-        additional_config = get_current_vllm_config().additional_config
-        if additional_config:
-            self.enable_cv_parallel = additional_config.get(
-                "enable_cv_parallel", False)
-
         self.experts = AscendFusedMoE(
             num_experts=config.n_routed_experts,
             top_k=config.num_experts_per_tok,
@@ -222,10 +216,13 @@ def __init__(
         self.params_dtype = torch.get_default_dtype()
 
         self.enable_graph_mode = False
+        self.enable_multistream_shared_expert = False
         additional_config = get_current_vllm_config().additional_config
         if additional_config:
             self.enable_graph_mode = additional_config.get(
                 "enable_graph_mode", False)
+            self.enable_multistream_shared_expert = additional_config.get(
+                "enable_multistream_shared_expert", False)
 
     def forward(
         self,
@@ -248,10 +245,10 @@ def forward(
 
         num_tokens, hidden_size = hidden_states.shape
 
-        cv_parallel = self.enable_cv_parallel and not is_prefill
+        multistream = self.enable_multistream_shared_expert and not is_prefill
 
         if self.n_shared_experts is not None:
-            if not cv_parallel:
+            if not multistream:
                 shared_output = self.shared_experts(hidden_states)
             else:
                 shared_hidden_states = hidden_states
@@ -275,41 +272,25 @@ def forward(
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
 
-        if self.n_shared_experts is not None and cv_parallel:
-            with tng.scope.npu_stream_switch('cv'):
-                tng.scope.npu_wait_tensor(shared_hidden_states, router_logits)
-                dynamic_scale = None
-                if self.shared_experts.is_dynamic_quant:
-                    x, dynamic_scale = torch_npu.npu_dynamic_quant(
-                        shared_hidden_states)
-                    gate_up = torch_npu.npu_quant_matmul(
-                        x,
-                        self.shared_experts.gate_up_proj.weight,
-                        self.shared_experts.gate_up_proj.weight_scale,
-                        output_dtype=torch.int32,
-                    )
-                else:
-                    gate_up, _ = self.gate_up_proj(shared_hidden_states)
-
-        if cv_parallel:
-            hidden_states, shared_output = self.experts(
-                hidden_states=hidden_states,
-                router_logits=router_logits,
-                is_prefill=is_prefill,
-                top_k=CustomDeepseekV2MoE.top_k,
-                enable_force_load_balance=enable_force_load_balance,
-                shared_experts=self.shared_experts,
-                shared_gate_up=gate_up,
-                shared_dynamic_scale=dynamic_scale)
-            hidden_states = hidden_states * self.routed_scaling_factor
-        else:
-            hidden_states = self.experts(
-                hidden_states=hidden_states,
-                router_logits=router_logits,
-                is_prefill=is_prefill,
-                top_k=CustomDeepseekV2MoE.top_k,
-                enable_force_load_balance=enable_force_load_balance,
-            ) * self.routed_scaling_factor
+        kwargs = {}
+        if multistream:
+            kwargs.update({
+                "shared_experts": self.shared_experts,
+                "shared_hidden_states": shared_hidden_states
+            })
+
+        hidden_states = self.experts(
+            hidden_states=hidden_states,
+            router_logits=router_logits,
+            is_prefill=is_prefill,
+            top_k=CustomDeepseekV2MoE.top_k,
+            enable_force_load_balance=enable_force_load_balance,
+            **kwargs)
+
+        if multistream:
+            hidden_states, shared_output = hidden_states
+
+        hidden_states = hidden_states * self.routed_scaling_factor
 
         if self.tp_size > 1:
             if self.enable_graph_mode:
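Taken together, the deepseek_v2.py change collapses the two call paths into one: optional shared-expert inputs now travel through **kwargs, and the experts layer hands back a (routed, shared) pair only in multistream decode. Below is a condensed, framework-free restatement of that control flow; experts and shared_experts here are stand-in callables, not the real AscendFusedMoE modules.

# Condensed sketch of the refactored call pattern; `experts` is a stand-in
# callable that mirrors AscendFusedMoE's new contract.
def run_moe(experts, shared_experts, hidden_states, router_logits,
            is_prefill, enable_multistream_shared_expert):
    multistream = enable_multistream_shared_expert and not is_prefill

    kwargs = {}
    if multistream:
        # defer the shared-expert work to the experts layer so it can be
        # overlapped with dispatch on a second stream
        kwargs.update({
            "shared_experts": shared_experts,
            "shared_hidden_states": hidden_states,
        })
    else:
        shared_output = shared_experts(hidden_states)

    out = experts(hidden_states=hidden_states,
                  router_logits=router_logits,
                  is_prefill=is_prefill,
                  **kwargs)

    if multistream:
        out, shared_output = out  # tuple return only in this mode
    return out, shared_output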

vllm_ascend/ops/fused_moe.py

Lines changed: 5 additions & 5 deletions
@@ -834,13 +834,13 @@ def __init__(
         self.quant_method.create_weights(layer=self, **moe_quant_params)
 
         self.enable_graph_mode = False
-        self.enable_cv_parallel = False
+        self.enable_multistream_shared_expert = False
         additional_config = get_current_vllm_config().additional_config
         if additional_config:
             self.enable_graph_mode = additional_config.get(
                 "enable_graph_mode", False)
-            self.enable_cv_parallel = additional_config.get(
-                "enable_cv_parallel", False)
+            self.enable_multistream_shared_expert = additional_config.get(
+                "enable_multistream_shared_expert", False)
 
     def forward(self,
                 hidden_states: torch.Tensor,
@@ -895,7 +895,7 @@ def forward(self,
             enable_force_load_balance=enable_force_load_balance,
             **kwargs)
 
-        if self.enable_cv_parallel and not is_prefill:
+        if self.enable_multistream_shared_expert and not is_prefill:
             hidden_states, shared_output = hidden_states
 
         if self.dp_size > 1:
@@ -920,6 +920,6 @@ def forward(self,
         if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
             hidden_states = tensor_model_parallel_all_reduce(hidden_states)
 
-        if self.enable_cv_parallel and not is_prefill:
+        if self.enable_multistream_shared_expert and not is_prefill:
             return hidden_states, shared_output
         return hidden_states

vllm_ascend/quantization/w8a8_dynamic.py

Lines changed: 18 additions & 0 deletions
@@ -184,6 +184,24 @@ def fused_experts_with_mc2(hidden_states: torch.Tensor,
     }
     kwargs_mc2.update(stage1_kwargs)
 
+    shared_experts = kwargs.get('shared_experts', None)
+    if shared_experts:
+        shared_hidden_states = kwargs.get('shared_hidden_states', None)
+        with tng.scope.npu_stream_switch('cv'):
+            tng.scope.npu_wait_tensor(shared_hidden_states, hidden_states)
+            shared_x, shared_dynamic_scale = torch_npu.npu_dynamic_quant(
+                shared_hidden_states)
+            shared_gate_up = torch_npu.npu_quant_matmul(
+                shared_x,
+                shared_experts.gate_up_proj.weight,
+                shared_experts.gate_up_proj.weight_scale,
+                output_dtype=torch.int32,
+            )
+        kwargs.update({
+            "shared_gate_up": shared_gate_up,
+            "shared_dynamic_scale": shared_dynamic_scale,
+        })
+
     output = torch_npu.npu_moe_distribute_dispatch(**kwargs_mc2)
     # comm_stream.wait_stream(torch.npu.current_stream())
     expand_x, dynamic_scale, expand_idx, expert_token_nums, ep_recv_counts = output[
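The block added to fused_experts_with_mc2 runs the shared expert's quantized gate_up projection on the side stream 'cv' so it overlaps with the MC2 dispatch on the main stream. As a rough analogy only, the same overlap pattern written with plain CUDA streams; the commit itself relies on torchair's tng.scope.npu_stream_switch / npu_wait_tensor and torch_npu quant kernels on Ascend NPUs.

import torch

# CUDA-stream analogy of the 'cv' side-stream overlap; illustrative only,
# the actual commit targets Ascend NPUs via torchair, not CUDA.
side_stream = torch.cuda.Stream()

def overlapped_step(dispatch_fn, shared_fn, hidden_states, shared_hidden_states):
    with torch.cuda.stream(side_stream):
        # counterpart of npu_wait_tensor: do not read inputs before the
        # main stream has produced them
        side_stream.wait_stream(torch.cuda.current_stream())
        shared_out = shared_fn(shared_hidden_states)  # runs concurrently

    routed_out = dispatch_fn(hidden_states)  # main-stream dispatch/MoE work

    # join the streams before the main stream consumes shared_out
    torch.cuda.current_stream().wait_stream(side_stream)
    return routed_out, shared_out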
