
Commit f1543d5

[bugfix] fix deepseek accuracy (#1118)
### What this PR does / why we need it?

Fix DeepSeek accuracy in the mixed-parallel case.

Signed-off-by: zzzzwwjj <1183291235@qq.com>
1 parent c874214 commit f1543d5

File tree

3 files changed (+23 -27 lines):

vllm_ascend/models/deepseek_v2.py
vllm_ascend/ops/fused_moe.py
vllm_ascend/platform.py

vllm_ascend/models/deepseek_v2.py

Lines changed: 16 additions & 25 deletions
```diff
@@ -67,6 +67,7 @@
 
 import vllm_ascend.envs as envs_ascend
 from vllm_ascend.ascend_config import get_ascend_config
+from vllm_ascend.distributed.parallel_state import get_ep_group
 from vllm_ascend.ops.fused_moe import AscendFusedMoE
 from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod
 from vllm_ascend.utils import dispose_tensor
```
```diff
@@ -211,13 +212,15 @@ def __init__(
 
         self.tp_group = get_tp_group().device_group
         self.tp_rank = get_tp_group().rank_in_group
+        self.ep_group = get_ep_group()
 
         self.params_dtype = torch.get_default_dtype()
 
         ascend_config = get_ascend_config()
         self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
+        # NOTE: multistream only effective when `VLLM_ENABLE_MC2` is on
         self.enable_multistream_shared_expert = \
-            ascend_config.torchair_graph_config.enable_multistream_shared_expert
+            ascend_config.torchair_graph_config.enable_multistream_shared_expert and VLLM_ENABLE_MC2
 
     def forward(
         self,
```
```diff
@@ -245,16 +248,12 @@ def forward(
         old_hidden_states = hidden_states.clone()
 
         if self.tp_size > 1:
-            if envs_ascend.VLLM_ENABLE_MC2 and not is_prefill:
-                chunks = torch.chunk(hidden_states, self.tp_size, dim=0)
-                hidden_states = chunks[self.tp_rank]
-            elif not self.torchair_graph_enabled:
-                num_padding_tokens = (self.tp_size -
-                                      num_tokens % self.tp_size) % self.tp_size
-                # Pad hidden_states to make it divisible by tp_size to avoid cross-ring AllGatherV on 910B2C
-                if num_padding_tokens > 0:
+            if (VLLM_ENABLE_MC2
+                    and not is_prefill) or not (self.torchair_graph_enabled or
+                                                self.ep_group.world_size == 1):
+                if num_tokens < self.tp_size:
                     hidden_states = nn.functional.pad(
-                        hidden_states, (0, 0, 0, num_padding_tokens))
+                        hidden_states, (0, 0, 0, self.tp_size - num_tokens))
                 chunk_hidden_states = torch.tensor_split(hidden_states,
                                                          self.tp_size,
                                                          dim=0)
```
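Read as a predicate, the new condition merges what used to be two separate branches (the MC2 decode path and the eager fallback) and additionally keeps the chunked path off whenever there is no expert parallelism. A minimal sketch of the decision logic, with a hypothetical helper name:

```python
def use_chunked_tp_path(vllm_enable_mc2: bool, is_prefill: bool,
                        torchair_graph_enabled: bool,
                        ep_world_size: int) -> bool:
    """Hypothetical helper mirroring the predicate in forward()."""
    # Chunked split/all-gather path: MC2 decode steps, or eager mode
    # with expert parallelism active; everything else all-reduces.
    return (vllm_enable_mc2 and not is_prefill) or not (
        torchair_graph_enabled or ep_world_size == 1)

# MC2 decode always takes the chunked path.
assert use_chunked_tp_path(True, False, True, ep_world_size=1)
# Eager mode with expert parallelism active: chunked path.
assert use_chunked_tp_path(False, True, False, ep_world_size=4)
# Pure TP in eager mode now falls through to the all-reduce branch.
assert not use_chunked_tp_path(False, True, False, ep_world_size=1)
```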
```diff
@@ -284,24 +283,16 @@ def forward(
         hidden_states = hidden_states * self.routed_scaling_factor
 
         if self.tp_size > 1:
-            if self.torchair_graph_enabled:
-                if envs_ascend.VLLM_ENABLE_MC2 and not is_prefill:
-                    final_hidden_states = torch.zeros(
-                        [num_tokens, hidden_size],
-                        dtype=self.params_dtype,
-                        device="npu")
-                    dist.all_gather_into_tensor(final_hidden_states,
-                                                hidden_states, self.tp_group)
-                    hidden_states = final_hidden_states
-                else:
-                    hidden_states = tensor_model_parallel_all_reduce(
-                        hidden_states)
-            else:
+            if (VLLM_ENABLE_MC2
+                    and not is_prefill) or not (self.torchair_graph_enabled or
+                                                self.ep_group.world_size == 1):
                 dist.all_gather(list(chunk_hidden_states), hidden_states,
                                 self.tp_group)
                 hidden_states = torch.cat(chunk_hidden_states, dim=0)
-                if num_padding_tokens > 0:
-                    hidden_states = hidden_states[:-num_padding_tokens]
+                if num_tokens < self.tp_size:
+                    hidden_states = hidden_states[:num_tokens]
+            else:
+                hidden_states = tensor_model_parallel_all_reduce(hidden_states)
 
         if self.n_shared_experts is not None:
             if not multistream:
```
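The two forward() hunks together implement a pad → split → all-gather → truncate round trip: the hidden states are padded so every rank holds an equal-sized chunk, and after the gather the padding rows are dropped again. A single-process sketch of that round trip (the collective is emulated by concatenating the chunks; shapes are illustrative):

```python
import torch
import torch.nn as nn

tp_size, num_tokens, hidden_size = 4, 3, 8  # num_tokens < tp_size, as in the fix
hidden_states = torch.randn(num_tokens, hidden_size)
original = hidden_states.clone()

# Pad the token dim up to tp_size rows (only needed when num_tokens <
# tp_size; torch.tensor_split also tolerates uneven splits).
if num_tokens < tp_size:
    hidden_states = nn.functional.pad(hidden_states,
                                      (0, 0, 0, tp_size - num_tokens))
chunk_hidden_states = torch.tensor_split(hidden_states, tp_size, dim=0)

# Each rank would process chunk_hidden_states[rank]; dist.all_gather then
# refills the list. Concatenating emulates the gather on one process.
hidden_states = torch.cat(chunk_hidden_states, dim=0)
if num_tokens < tp_size:
    hidden_states = hidden_states[:num_tokens]

assert torch.equal(hidden_states, original)
```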

vllm_ascend/ops/fused_moe.py

Lines changed: 2 additions & 1 deletion
```diff
@@ -1027,8 +1027,9 @@ def __init__(
 
         ascend_config = get_ascend_config()
         self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
+        # NOTE: multistream only effective when `VLLM_ENABLE_MC2` is on
         self.enable_multistream_shared_expert = \
-            ascend_config.torchair_graph_config.enable_multistream_shared_expert
+            ascend_config.torchair_graph_config.enable_multistream_shared_expert and VLLM_ENABLE_MC2
 
         if self.scoring_func != "softmax" and not self.use_grouped_topk:
             raise ValueError("Only softmax scoring function is supported for "
```
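This applies the same gating as in deepseek_v2.py: the multistream shared-expert option only takes effect when VLLM_ENABLE_MC2 is set, matching the NOTE added in the diff. A minimal sketch of the flag resolution, using a stand-in for vllm_ascend.envs and a SimpleNamespace in place of the real config object:

```python
import os
from types import SimpleNamespace

# Stand-in for vllm_ascend.envs.VLLM_ENABLE_MC2 (assumed env-var parsing).
VLLM_ENABLE_MC2 = bool(int(os.getenv("VLLM_ENABLE_MC2", "0")))

torchair_graph_config = SimpleNamespace(enabled=True,
                                        enable_multistream_shared_expert=True)

# Even with the config option on, multistream stays off unless MC2 is on.
enable_multistream_shared_expert = (
    torchair_graph_config.enable_multistream_shared_expert and VLLM_ENABLE_MC2)
```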

vllm_ascend/platform.py

Lines changed: 5 additions & 1 deletion
```diff
@@ -142,7 +142,11 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
 
         # NOTE: When enable_expert_parallel is True, we follow vLLM convention:
         # ep_size = world_size, which means expert_tensor_parallel_size must be 1
-        if ascend_config.expert_tensor_parallel_size > 0 and not parallel_config.enable_expert_parallel:
+        if parallel_config.enable_expert_parallel:
+            parallel_config.expert_tensor_parallel_size = 1
+        # NOTE: When enable_expert_parallel is False and param `ascend_config.expert_tensor_parallel_size`
+        # is configured, use ascend_config
+        elif ascend_config.expert_tensor_parallel_size > 0:
             parallel_config.expert_tensor_parallel_size = ascend_config.expert_tensor_parallel_size
 
         # Calculate expert parallel size based on world size
```
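The rewritten branch makes the precedence explicit: enable_expert_parallel wins and pins expert_tensor_parallel_size to 1; otherwise a positive ascend_config.expert_tensor_parallel_size overrides it. A standalone sketch of that resolution order (the function name and the final fallback are hypothetical):

```python
def resolve_expert_tp_size(enable_expert_parallel: bool,
                           ascend_expert_tp_size: int,
                           current: int) -> int:
    """Hypothetical mirror of the precedence in check_and_update_config."""
    if enable_expert_parallel:
        # vLLM convention: ep_size = world_size, so expert TP must be 1.
        return 1
    if ascend_expert_tp_size > 0:
        # Explicitly configured ascend_config value takes effect.
        return ascend_expert_tp_size
    return current  # otherwise leave parallel_config untouched

assert resolve_expert_tp_size(True, ascend_expert_tp_size=8, current=4) == 1
assert resolve_expert_tp_size(False, ascend_expert_tp_size=8, current=4) == 8
assert resolve_expert_tp_size(False, ascend_expert_tp_size=0, current=4) == 4
```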
