
Commit da9acfc

feat: support data parallel for deepseek (#1012)
### What this PR does / why we need it?

feat: support data parallel for deepseek

### Does this PR introduce _any_ user-facing change?

Yes, data parallelism is now supported for DeepSeek.

### How was this patch tested?

```
export VLLM_ENABLE_MC2=0
export VLLM_USE_V1=1
export TASK_QUEUE_ENABLE=1
source /usr/local/Ascend/ascend-toolkit/set_env.sh
source /usr/local/Ascend/nnal/atb/set_env.sh
nohup python -m vllm.entrypoints.openai.api_server --model=/path/to/DeepSeek-R1-W8A8 \
    --quantization ascend \
    --served-model-name auto \
    --trust-remote-code \
    --distributed-executor-backend=mp \
    --port 8006 \
    -tp=8 \
    -dp=2 \
    --max-num-seqs 24 \
    --max-model-len 4096 \
    --max-num-batched-tokens 4096 \
    --block-size 128 \
    -O 0 \
    --no-enable-prefix-caching \
    --additional-config '{"torchair_graph_batch_sizes":[24],"expert_tensor_parallel_size":16,"ascend_scheduler_config":{},"enable_graph_mode":true}' \
    --gpu-memory-utilization 0.95 &> run.log &
disown
```

Signed-off-by: boying <897013703@qq.com>
Parent: 5178114 · Commit: da9acfc

8 files changed: +212 additions, -88 deletions


vllm_ascend/attention/mla_v1.py (23 additions, 9 deletions)

@@ -117,6 +117,8 @@ class AscendMLAMetadata:
     # For logging.
     num_input_tokens: int = 0  # Number of tokens including padding.

+    with_prefill_across_dp: bool = False
+
     # The dimension of the attention heads
     head_dim: Optional[int] = None
     attn_mask: torch.Tensor = None

@@ -260,6 +262,10 @@ def build_dummy(self, num_reqs: int,
                                  PAD_SLOT_ID,
                                  dtype=torch.int32,
                                  device=device)
+        query_start_loc = torch.full((num_reqs, ),
+                                     -1,
+                                     dtype=torch.int32,
+                                     device=device)
         decode_metadata = AscendMLADecodeMetadata(
             input_positions=input_positions,
             block_table=block_table,

@@ -278,15 +284,21 @@ def build_dummy(self, num_reqs: int,
             attn_state=AscendAttentionState.DecodeOnly,
             prefill=None,
             decode=decode_metadata,
+            query_start_loc=query_start_loc,
+            seq_lens=seq_lens,
+            block_tables=block_table,
         )

-    def build(self,
-              num_reqs: int,
-              num_actual_tokens: int,
-              max_query_len: int,
-              common_attn_metadata: CommonAttentionMetadata,
-              common_prefix_len: Optional[int] = None,
-              graph_pad_size: int = -1) -> AscendMLAMetadata:
+    def build(
+        self,
+        num_reqs: int,
+        num_actual_tokens: int,
+        max_query_len: int,
+        common_attn_metadata: CommonAttentionMetadata,
+        common_prefix_len: Optional[int] = None,
+        graph_pad_size: int = -1,
+        with_prefill_across_dp: bool = False,
+    ) -> AscendMLAMetadata:
         assert self._num_decodes + self._num_prefills == num_reqs

         # Note(simon): be careful about the CPU <> GPU memory movement in this

@@ -388,6 +400,7 @@ def build(self,
             query_start_loc=query_start_loc,
             block_tables=block_table,
             seq_lens=seq_lens,
+            with_prefill_across_dp=with_prefill_across_dp,
         )

@@ -621,7 +634,7 @@ def exec_kv(
         kv = self.kv_a_proj_with_mqa(hidden_states)[0]
         # npu_kv_rmsnorm_rope_cache needs [B, N, S, D]
         kv = kv.view(B, N, S, self.kv_lora_rank + self.qk_rope_head_dim)
-        k_pe, k_nope, _, _ = torch.ops.npu_inference.npu_kv_rmsnorm_rope_cache(
+        k_pe, k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache(
            kv,
            self.kv_a_layernorm.weight,
            cos,

@@ -643,7 +656,7 @@ def rope_single(
         B, N, D = x.shape
         S = 1
         x = x.view(B, N, S, D)
-        x = torch.ops.npu_inference.npu_interleave_rope(x, cos, sin)
+        x = torch_npu.npu_interleave_rope(x, cos, sin)
         return x.view(B, N, D)

     def _forward_decode(

@@ -766,6 +779,7 @@ def forward(
             sin = sin[attn_metadata.decode.input_positions]
             cos = cos[:, None, None, :]
             sin = sin[:, None, None, :]
+
             decode_q_pe = self.rope_single(decode_q_pe, cos, sin)
             decode_k_pe, decode_k_nope = self.exec_kv(
                 hidden_states_or_kv_c_normed, cos, sin, kv_cache,
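The new `with_prefill_across_dp` field and the matching `build(..., with_prefill_across_dp=...)` parameter thread a data-parallel-wide prefill flag into the attention metadata; `CustomDeepseekV2MoE.forward` (next file) then ORs it into `is_prefill`. Below is a minimal stand-alone sketch of that threading, assuming the flag means "some rank in the DP group is running prefill"; the dataclass and builder are illustrative stand-ins, not the PR's classes.

```
from dataclasses import dataclass


@dataclass
class MiniMLAMetadata:
    # Simplified stand-in for AscendMLAMetadata; only the fields needed here.
    num_input_tokens: int = 0
    with_prefill_across_dp: bool = False


def build_metadata(num_tokens: int, local_has_prefill: bool,
                   prefill_flags_across_dp: list) -> MiniMLAMetadata:
    # If any rank in the DP group runs prefill, every rank takes the prefill
    # path so the collective ops they issue stay in lockstep (assumed semantics).
    return MiniMLAMetadata(
        num_input_tokens=num_tokens,
        with_prefill_across_dp=local_has_prefill or any(prefill_flags_across_dp),
    )


meta = build_metadata(4, local_has_prefill=False,
                      prefill_flags_across_dp=[False, True])
assert meta.with_prefill_across_dp is True
```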

vllm_ascend/models/deepseek_v2.py (49 additions, 28 deletions)

@@ -212,6 +212,14 @@ def __init__(
         self.tp_group = get_tp_group().device_group
         self.tp_rank = get_tp_group().rank_in_group

+        self.params_dtype = torch.get_default_dtype()
+
+        self.enable_graph_mode = False
+        additional_config = get_current_vllm_config().additional_config
+        if additional_config:
+            self.enable_graph_mode = additional_config.get(
+                "enable_graph_mode", False)
+
     def forward(
         self,
         hidden_states: torch.Tensor,

@@ -228,52 +236,65 @@ def forward(
         else:
             is_prefill = attn_metadata.num_prefills > 0
             enable_force_load_balance = False
-        num_tokens, hidden_dim = hidden_states.shape
+            if hasattr(attn_metadata, 'with_prefill_across_dp'):
+                is_prefill = is_prefill or attn_metadata.with_prefill_across_dp
+
+        num_tokens, hidden_size = hidden_states.shape

         if self.n_shared_experts is not None:
             shared_output = self.shared_experts(hidden_states)

         if self.tp_size > 1:
-            # pass
-            num_tokens, hidden_size = hidden_states.shape
-            if num_tokens < self.tp_size:
-                target_size = self.tp_size
-                new_hidden_states = torch.empty([target_size, hidden_size],
-                                                dtype=hidden_states.dtype,
-                                                device=hidden_states.device)
-                new_hidden_states[:num_tokens] = hidden_states
-                hidden_states = new_hidden_states
-            chunk_hidden_states = torch.tensor_split(hidden_states,
-                                                     self.tp_size,
-                                                     dim=0)
-            local_hidden_states = chunk_hidden_states[self.tp_rank]
-        else:
-            local_hidden_states = hidden_states
+            if envs_ascend.VLLM_ENABLE_MC2 and not is_prefill:
+                chunks = torch.chunk(hidden_states, self.tp_size, dim=0)
+                hidden_states = chunks[self.tp_rank]
+            elif not self.enable_graph_mode:
+                num_padding_tokens = (self.tp_size -
+                                      num_tokens % self.tp_size) % self.tp_size
+                # Pad hidden_states to make it divisible by tp_size to avoid cross-ring AllGatherV on 910B2C
+                if num_padding_tokens > 0:
+                    hidden_states = nn.functional.pad(
+                        hidden_states, (0, 0, 0, num_padding_tokens))
+                chunk_hidden_states = torch.tensor_split(hidden_states,
+                                                         self.tp_size,
+                                                         dim=0)
+                hidden_states = chunk_hidden_states[self.tp_rank]

         # router_logits: (num_tokens, n_experts)
-        router_logits, _ = self.gate(local_hidden_states)
+        router_logits, _ = self.gate(hidden_states)

-        router_hidden_states = self.experts(
-            hidden_states=local_hidden_states,
+        hidden_states = self.experts(
+            hidden_states=hidden_states,
             router_logits=router_logits,
             is_prefill=is_prefill,
             top_k=CustomDeepseekV2MoE.top_k,
             enable_force_load_balance=enable_force_load_balance,
         ) * self.routed_scaling_factor

         if self.tp_size > 1:
-            dist.all_gather(list(chunk_hidden_states), router_hidden_states,
-                            self.tp_group)
-            final_hidden_states = torch.cat(chunk_hidden_states, dim=0)
-            if num_tokens < self.tp_size:
-                final_hidden_states = final_hidden_states[:num_tokens]
-        else:
-            final_hidden_states = router_hidden_states
+            if self.enable_graph_mode:
+                if envs_ascend.VLLM_ENABLE_MC2 and not is_prefill:
+                    final_hidden_states = torch.zeros(
+                        [num_tokens, hidden_size],
+                        dtype=self.params_dtype,
+                        device="npu")
+                    dist.all_gather_into_tensor(final_hidden_states,
+                                                hidden_states, self.tp_group)
+                    hidden_states = final_hidden_states
+                else:
+                    hidden_states = tensor_model_parallel_all_reduce(
+                        hidden_states)
+            else:
+                dist.all_gather(list(chunk_hidden_states), hidden_states,
+                                self.tp_group)
+                hidden_states = torch.cat(chunk_hidden_states, dim=0)
+                if num_padding_tokens > 0:
+                    hidden_states = hidden_states[:-num_padding_tokens]

         if shared_output is not None:
-            final_hidden_states = final_hidden_states + shared_output
+            hidden_states = hidden_states + shared_output

-        return final_hidden_states.view(num_tokens, hidden_dim)
+        return hidden_states.view(num_tokens, hidden_size)


 class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention):
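The non-graph TP path above replaces the old "grow to tp_size rows" workaround with padding to the next multiple of `tp_size`, splitting, and stripping the padding again after the all-gather. Here is a single-process sketch of just that bookkeeping, with `torch.cat` standing in for `dist.all_gather`; the function name is illustrative and not part of the PR.

```
import torch
import torch.nn.functional as F


def split_and_restore(hidden_states: torch.Tensor, tp_size: int, tp_rank: int):
    num_tokens, hidden_size = hidden_states.shape

    # Pad so the token count divides evenly across TP ranks.
    num_padding_tokens = (tp_size - num_tokens % tp_size) % tp_size
    if num_padding_tokens > 0:
        hidden_states = F.pad(hidden_states, (0, 0, 0, num_padding_tokens))

    # Each rank keeps only its chunk for the expert computation.
    chunks = torch.tensor_split(hidden_states, tp_size, dim=0)
    local = chunks[tp_rank]

    # After the experts run, the chunks are gathered back and padding dropped.
    gathered = torch.cat(chunks, dim=0)  # stand-in for dist.all_gather
    if num_padding_tokens > 0:
        gathered = gathered[:-num_padding_tokens]
    return local, gathered


local, restored = split_and_restore(torch.randn(5, 16), tp_size=4, tp_rank=1)
assert restored.shape == (5, 16)
```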

vllm_ascend/ops/fused_moe.py (65 additions, 26 deletions)

@@ -587,6 +587,12 @@ def __init__(self, moe: MoEConfig = None):
         self.global_batch_size = vllm_config.scheduler_config.max_num_seqs
         self.local_batch_size = self.global_batch_size // self.ep_size

+        self.enable_graph_mode = False
+        additional_config = get_current_vllm_config().additional_config
+        if additional_config:
+            self.enable_graph_mode = additional_config.get(
+                "enable_graph_mode", False)
+
         try:
             device_group = ep_group.device_group
             # TODO: Try local_rank = ep_group.rank_in_group

@@ -664,7 +670,7 @@ def apply(
                 top_k=top_k,
                 expert_map=expert_map,
                 moe_all_to_all_group_name=self.moe_all_to_all_group_name)
-        elif get_ep_group().world_size == 1:
+        elif self.enable_graph_mode or get_ep_group().world_size == 1:
             return fused_experts(hidden_states=x,
                                  w1=layer.w13_weight,
                                  w2=layer.w2_weight,

@@ -750,26 +756,20 @@ def __init__(
         self.expert_map = None
         self.activation = activation

-        if self.ep_size > 1:
-            # Create a tensor of size num_experts filled with -1
-            self.local_num_experts, self.expert_map = determine_expert_map(
-                self.ep_size,
-                get_ep_group().rank_in_group, self.global_num_experts)
-
-            self.moe_parallel_config.tp_rank = get_etp_group().rank_in_group
-            self.moe_parallel_config.ep_rank = get_ep_group().rank_in_group
+        # Create a tensor of size num_experts filled with -1
+        self.local_num_experts, self.expert_map = determine_expert_map(
+            self.ep_size,
+            get_ep_group().rank_in_group, self.global_num_experts)

-        else:
-            # Adjust TP size for DP attention
-            # haven't test its functionality yet, may remove in the future
+        self.moe_parallel_config.tp_rank = get_etp_group().rank_in_group
+        self.moe_parallel_config.ep_rank = get_ep_group().rank_in_group

-            self.moe_parallel_config.tp_rank = self.tp_size * self.dp_rank
-            self.moe_parallel_config.ep_rank = 0
-            self.moe_parallel_config.tp_size = self.tp_size * self.dp_size
-            self.moe_parallel_config.ep_size = 1
+        self.enable_graph_mode = False
+        additional_config = get_current_vllm_config().additional_config
+        if additional_config:
+            self.enable_graph_mode = additional_config.get(
+                "enable_graph_mode", False)

-            self.local_num_experts, self.expert_map = (self.global_num_experts,
-                                                       None)
         if self.scoring_func != "softmax" and not self.use_grouped_topk:
             raise ValueError("Only softmax scoring function is supported for "
                              "non-grouped topk.")

@@ -807,8 +807,15 @@ def __init__(
                 in ("GPTQMarlinMoEMethod", "CompressedTensorsWNA16MoEMethod")):
             moe_quant_params["intermediate_size_full"] = intermediate_size

+        self.ep_group = get_ep_group()
         self.quant_method.create_weights(layer=self, **moe_quant_params)

+        self.enable_graph_mode = False
+        additional_config = get_current_vllm_config().additional_config
+        if additional_config:
+            self.enable_graph_mode = additional_config.get(
+                "enable_graph_mode", False)
+
     def forward(self,
                 hidden_states: torch.Tensor,
                 router_logits: torch.Tensor,

@@ -822,11 +829,28 @@ def forward(self,
         else:
             real_top_k = self.top_k

-        if VLLM_ENABLE_MC2 and not is_prefill:
-            ...
+        #                MC2   ag/rs   broadcast/all_reduce
+        # prefill_req     x      x              √
+        # decode_req      √      x              √
+        # graph_mode      √      √              x
+        if self.dp_size > 1:
+            if VLLM_ENABLE_MC2 and not is_prefill:
+                ...
+            elif self.enable_graph_mode:
+                if USING_LCCL_COM:  # type: ignore
+                    hidden_states = get_dp_group().all_gather(
+                        hidden_states, 0, False)
+                    router_logits = get_dp_group().all_gather(
+                        router_logits, 0, False)
+                elif self.enable_graph_mode and not is_prefill:
+                    hidden_states = get_dp_group().all_gather(hidden_states, 0)
+                    router_logits = get_dp_group().all_gather(router_logits, 0)
+                else:
+                    hidden_states, router_logits = get_ep_group().dispatch(
+                        hidden_states, router_logits)

         # Matrix multiply.
-        final_hidden_states = self.quant_method.apply(
+        hidden_states = self.quant_method.apply(
             layer=self,
             x=hidden_states,
             router_logits=router_logits,

@@ -843,11 +867,26 @@ def forward(self,
             is_prefill=is_prefill,
             enable_force_load_balance=enable_force_load_balance)

-        if VLLM_ENABLE_MC2 and not is_prefill:
-            ...
+        if self.dp_size > 1:
+            if VLLM_ENABLE_MC2 and not is_prefill:
+                ...
+            elif self.enable_graph_mode:
+                if USING_LCCL_COM:  # type: ignore
+                    hidden_states = dist._functional_collectives.reduce_scatter_tensor(
+                        hidden_states,
+                        "sum",
+                        scatter_dim=0,
+                        group=get_dp_group().device_group)
+                elif self.enable_graph_mode and not is_prefill:
+                    hidden_states = dist._functional_collectives.reduce_scatter_tensor(
+                        hidden_states,
+                        "sum",
+                        scatter_dim=0,
+                        group=get_dp_group().device_group)
+                else:
+                    hidden_states = get_ep_group().combine(hidden_states)

         if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
-            final_hidden_states = tensor_model_parallel_all_reduce(
-                final_hidden_states)
+            hidden_states = tensor_model_parallel_all_reduce(hidden_states)

-        return final_hidden_states
+        return hidden_states
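When `dp_size > 1`, the forward pass above surrounds `quant_method.apply` with a communication step: in graph mode an all-gather of tokens over the DP group beforehand and a reduce-scatter of the summed expert output afterwards, with `get_ep_group().dispatch` / `combine` as the fallback path. Below is a single-process sketch of only the gather / compute / reduce-scatter shape, with plain `torch` standing in for `get_dp_group().all_gather` and `reduce_scatter_tensor`; it is purely illustrative and not the PR's API.

```
import torch

dp_size, tokens_per_rank, hidden = 2, 3, 8
shards = [torch.randn(tokens_per_rank, hidden) for _ in range(dp_size)]

# all_gather(dim=0): every DP rank ends up with the full token batch.
gathered = torch.cat(shards, dim=0)  # [dp_size * tokens_per_rank, hidden]

# Each rank produces a partial expert output for the full batch
# (identity stand-in here for the real expert computation).
per_rank_partials = [gathered.clone() for _ in range(dp_size)]

# reduce_scatter("sum", scatter_dim=0): sum the partials across ranks, then
# each rank keeps only the slice corresponding to its own tokens.
reduced = torch.stack(per_rank_partials).sum(dim=0)
rank0_out = reduced.chunk(dp_size, dim=0)[0]
assert rank0_out.shape == (tokens_per_rank, hidden)
```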

vllm_ascend/platform.py (3 additions, 1 deletion)

@@ -138,7 +138,7 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:

         # Calculate expert parallel size based on world size
         parallel_config.expert_parallel_size = (
-            parallel_config.world_size //
+            parallel_config.world_size_across_dp //
             parallel_config.expert_tensor_parallel_size)

         if model_config is None:

@@ -167,6 +167,8 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
                 raise NotImplementedError(
                     "enable_graph_mode only works with deepseek model."
                 )
+            # Set compilation level to NO_COMPILATION to disable ACL Graph
+            compilation_config.level = CompilationLevel.NO_COMPILATION

         elif envs.VLLM_USE_V1 and model_config is not None and not enforce_eager:
             model_type = model_config.hf_config.model_type
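For reference, the changed formula with the launch command from the commit message plugged in (tp=8, dp=2, `expert_tensor_parallel_size=16`). The numbers below are only a worked illustration, assuming `world_size_across_dp` is the TP x DP device count; the attribute names follow the diff.

```
# Worked example only; values come from the test command above.
tensor_parallel_size = 8
data_parallel_size = 2
expert_tensor_parallel_size = 16

world_size_across_dp = tensor_parallel_size * data_parallel_size  # 16 devices
expert_parallel_size = world_size_across_dp // expert_tensor_parallel_size
assert expert_parallel_size == 1
# Dividing the per-DP-rank world_size (8) instead would give 0 once dp > 1,
# which is presumably why the divisor changed.
```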

vllm_ascend/quantization/w8a8_dynamic.py (8 additions, 1 deletion)

@@ -20,6 +20,7 @@
 import torch
 import torch.distributed as dist
 import torch_npu
+from vllm.config import get_current_vllm_config
 from vllm.distributed import GroupCoordinator

 import vllm_ascend.envs as envs_ascend

@@ -508,6 +509,12 @@ def __init__(self):

         self.ep_group = get_ep_group()

+        self.enable_graph_mode = False
+        additional_config = get_current_vllm_config().additional_config
+        if additional_config:
+            self.enable_graph_mode = additional_config.get(
+                "enable_graph_mode", False)
+
         try:
             device_group = self.ep_group.device_group
             # TODO: Try local_rank = ep_group.rank_in_group

@@ -629,7 +636,7 @@ def apply(
                 top_k=top_k,
                 expert_map=expert_map,
                 moe_all_to_all_group_name=self.moe_all_to_all_group_name)
-        elif self.ep_group.world_size == 1:
+        elif self.enable_graph_mode or self.ep_group.world_size == 1:
             return fused_experts(hidden_states=x,
                                  w1=layer.w13_weight,
                                  w1_scale=layer.w13_weight_scale,
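The same `additional_config` lookup now appears in several classes touched by this PR (the DeepSeek MoE layer, `AscendFusedMoE`, and this w8a8 dynamic method). A small stand-alone sketch of that pattern, written against a plain dict so it runs without vLLM; the helper name is illustrative and not part of the PR.

```
from typing import Optional


def read_enable_graph_mode(additional_config: Optional[dict]) -> bool:
    # Mirrors additional_config.get("enable_graph_mode", False) with the
    # falsy-config guard used in the diffs above.
    if additional_config:
        return bool(additional_config.get("enable_graph_mode", False))
    return False


assert read_enable_graph_mode({"enable_graph_mode": True}) is True
assert read_enable_graph_mode(None) is False
```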
