
Commit 5c6d05a

zzzzwwjj and wen-jie666 authored
support deepseek quant & mix-parallel with graphmode (#585)
### What this PR does / why we need it?

1. Support DeepSeek with W8A8 quantization.
2. Support DeepSeek with mixed parallelism (multi-DP, EP + TP).
3. Support DeepSeek with graph mode.

---------

Signed-off-by: wen-jie666 <wenjie39@huawei.com>
Signed-off-by: Yizhou Liu <liuyizhou5@h-partners.com>
Signed-off-by: libaokui <libaokui@huawei.com>
Signed-off-by: linfeng-yuan <1102311262@qq.com>
Co-authored-by: wen-jie666 <wenjie39@huawei.com>
1 parent e74331a commit 5c6d05a
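For readers skimming the diffs below: the graph-mode switch that used to be the VLLM_ENABLE_GRAPH_MODE environment variable, and the expert_tensor_parallel_size constructor kwarg, now both live in additional_config. A minimal offline sketch of the new call pattern, with the model name and parallel sizes copied from the updated example as placeholders:

# Minimal sketch of the additional_config entry points added in this commit.
# Model name and parallel sizes are illustrative, taken from the example diff.
from vllm import LLM, SamplingParams

llm = LLM(model="deepseek-ai/DeepSeek-V2-Lite-Chat",
          tensor_parallel_size=1,
          trust_remote_code=True,
          max_model_len=4096,
          additional_config={
              'expert_tensor_parallel_size': 1,  # moved out of the LLM kwargs
              'enable_graph_mode': True,  # replaces VLLM_ENABLE_GRAPH_MODE=1
          })

outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(max_tokens=4, min_tokens=4))
print(outputs[0].outputs[0].text)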

13 files changed, +522 −223 lines changed

examples/dp_offline/data_parallel.py

Lines changed: 11 additions & 13 deletions
@@ -11,17 +11,15 @@
 import gc
 import os
 
-VLLM_ENABLE_GRAPGH_MODE = os.environ.get("VLLM_ENABLE_GRAPH_MODE") == "1"
-
 
 def main():
     dp_rank = int(os.environ['RANK'])
     local_rank = int(os.environ['LOCAL_RANK'])
     dp_size = int(os.environ['WORLD_SIZE'])
     master_addr = os.environ['MASTER_ADDR']
     master_port = os.environ['MASTER_PORT']
-    tp_size = 4
-    etp_size = 2
+    tp_size = 1
+    etp_size = 1
 
     os.environ["VLLM_DP_RANK"] = str(dp_rank)
     os.environ["VLLM_DP_SIZE"] = str(dp_size)
@@ -58,15 +56,15 @@ def main():
                                      max_tokens=4,
                                      min_tokens=4)
     # Create an LLM.
-    llm = LLM(
-        model="deepseek-ai/DeepSeek-V2-Lite-Chat",
-        tensor_parallel_size=tp_size,
-        trust_remote_code=True,
-        expert_tensor_parallel_size=etp_size,
-        max_model_len=4096,
-        max_num_seqs=num_seqs,
-        compilation_config=1 if VLLM_ENABLE_GRAPGH_MODE else 0,
-    )
+    llm = LLM(model="deepseek-ai/DeepSeek-V2-Lite-Chat",
+              tensor_parallel_size=tp_size,
+              trust_remote_code=True,
+              max_model_len=4096,
+              max_num_seqs=num_seqs,
+              additional_config={
+                  'expert_tensor_parallel_size': etp_size,
+                  'enable_graph_mode': False,
+              })
 
     outputs = llm.generate(prompts, sampling_params)
     for output in outputs:
examples/dp_offline/run_dp.sh

Lines changed: 1 addition & 3 deletions
@@ -6,15 +6,13 @@ export HCCL_SOCKET_IFNAME=${ifname}
 # dp_size = node_size * dp_per_node
 node_size=1
 node_rank=0
-dp_per_node=2
+dp_per_node=4
 master_addr=127.0.0.1
 master_port=12345
 
 rm -rf ./.torchair_cache/
 rm -rf ./dynamo_*
 rm -rf /root/ascend/log/debug/plog/*
-export VLLM_ENABLE_GRAPH_MODE=0
-export VLLM_ENABLE_MC2=0
 
 torchrun --nproc_per_node ${dp_per_node} --nnodes ${node_size} \
     --node_rank ${node_rank} --master_addr ${master_addr} --master_port ${master_port} \
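The `# dp_size = node_size * dp_per_node` comment is the only launch arithmetic the script relies on; a trivial illustration with a hypothetical two-node launch (the script above uses node_size=1):

# Launch arithmetic behind run_dp.sh: torchrun starts dp_per_node processes per
# node, so the WORLD_SIZE that data_parallel.py reads as dp_size is the product.
node_size = 2      # hypothetical two-node launch; the script above uses 1
dp_per_node = 4    # value from the updated script

dp_size = node_size * dp_per_node
print(dp_size)     # 8 -> global RANK runs 0..7, LOCAL_RANK runs 0..3 per node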

vllm_ascend/attention/attention.py

Lines changed: 10 additions & 3 deletions
@@ -27,6 +27,7 @@
 except ImportError:
     print("Failed to import torch_npu.")
 
+import torchair._contrib.custom_torch_ops  # type: ignore  # noqa: F401
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionLayer,
                                               AttentionMetadata, AttentionType,
@@ -36,9 +37,9 @@
                                          compute_slot_mapping,
                                          compute_slot_mapping_start_idx,
                                          is_block_tables_empty)
+from vllm.config import get_current_vllm_config
 from vllm.utils import async_tensor_h2d, make_tensor_with_pad
 
-from vllm_ascend.utils import VLLM_ENABLE_GRAPH_MODE
 from vllm_ascend.worker.model_runner import (
     ModelInputForNPUBuilder, ModelInputForNPUWithSamplingMetadata)
 
@@ -913,6 +914,12 @@ def __init__(
         self.w_kc = None
         self.w_vc = None
 
+        self.enable_graph_mode = False
+        additional_config = get_current_vllm_config().additional_config
+        if additional_config:
+            self.enable_graph_mode = additional_config.get(
+                "enable_graph_mode", False)
+
     def exec_kv(
         self,
         hidden_states: torch.Tensor,
@@ -1084,7 +1091,7 @@ def forward(
                                  self.num_heads, -1)
 
         # TODO: Replace the env with more flexible expressions
-        if VLLM_ENABLE_GRAPH_MODE == '1':
+        if self.enable_graph_mode:
            if len(kv_cache) > 0 and kv_cache[0].numel(
            ) > 0 and attn_metadata.num_prefills > 0:
                slots = attn_metadata.slot_mapping
@@ -1141,7 +1148,7 @@ def forward(
             )
         elif attn_metadata.decode_metadata:
             assert kv_cache is not None
-            if VLLM_ENABLE_GRAPH_MODE == '1':
+            if self.enable_graph_mode:
                 # TorchAir's shape is [bs, num_heads_per_rank, seq_len, dim]
                 q_nope = q_nope.view(num_tokens, self.num_heads, 1, -1)
                 q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1)
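The same additional_config lookup is repeated verbatim in the attention backend above and in deepseek_v2.py below; condensed into a standalone helper (the function name is hypothetical, only the lookup itself comes from the diff), the pattern is:

# Hypothetical helper condensing the repeated lookup from this commit:
# read 'enable_graph_mode' out of vllm_config.additional_config, default False.
from vllm.config import get_current_vllm_config


def read_enable_graph_mode() -> bool:
    additional_config = get_current_vllm_config().additional_config
    if additional_config:
        return additional_config.get("enable_graph_mode", False)
    return False

Both call sites cache the result on self at construction time and branch on it in forward, instead of re-reading an environment variable.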

vllm_ascend/models/deepseek_v2.py

Lines changed: 98 additions & 32 deletions
@@ -26,13 +26,13 @@
 # """Inference-only DeepseekV2/DeepseekV3 model."""
 
 import os
-from typing import Any, Dict, Optional, Union
+from typing import Any, Dict, List, Optional, Union
 
 import torch
 import torch.distributed as dist
 from torch import nn
 from transformers import PretrainedConfig
-from vllm.attention import Attention
+from vllm.attention import Attention, AttentionMetadata
 from vllm.config import (CacheConfig, ModelConfig, VllmConfig,
                          get_current_vllm_config)
 from vllm.distributed import (get_dp_group, get_pp_group,
@@ -64,7 +64,6 @@
 from vllm.sequence import IntermediateTensors
 
 from vllm_ascend.ops.fused_moe import AscendFusedMoE
-from vllm_ascend.utils import VLLM_ENABLE_GRAPH_MODE
 
 
 class CustomDeepseekV2MoE(nn.Module):
@@ -133,7 +132,7 @@ def __init__(
         vllm_config = get_current_vllm_config()
         self.dp_size = get_dp_group().world_size
         batch_size = vllm_config.scheduler_config.max_num_seqs
-        self.enable_mc2 = int(os.environ.get("VLLM_ENABLE_MC2", 0)) == 1
+        self.enable_mc2 = int(os.environ.get("VLLM_ENABLE_MC2", '0')) == 1
 
         params_dtype = torch.get_default_dtype()
         self.final_hidden_states = torch.zeros(
@@ -309,38 +308,36 @@ def __init__(
 
         self.prefix = prefix
         self.debug_layer_idx = int(self.prefix.split(".")[-2])
-        if VLLM_ENABLE_GRAPH_MODE == "1":
-            self.forward = self.forward_torchair
-        else:
-            self.forward = self.forward_eager  # type: ignore
+        self.enable_graph_mode = False
+        additional_config = get_current_vllm_config().additional_config
+        if additional_config:
+            self.enable_graph_mode = additional_config.get(
+                "enable_graph_mode", False)
 
-    def forward_torchair(self,
-                         positions: torch.Tensor,
-                         hidden_states: torch.Tensor,
-                         kv_cache: torch.Tensor = None,
-                         attn_metadata=None):
+    def forward(
+            self,
+            positions: torch.Tensor,
+            hidden_states: torch.Tensor,
+            kv_cache: Optional[torch.Tensor] = None,
+            attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
         if self.q_lora_rank is not None:
             ckq = self.q_a_proj(hidden_states)[0]
             hidden_states_or_q_c = self.q_a_layernorm(ckq)
         else:
             hidden_states_or_q_c = hidden_states
-        return self.mla_attn(hidden_states_or_q_c, hidden_states, None,
-                             kv_cache, attn_metadata)
-
-    def forward_eager(self, positions: torch.Tensor,
-                      hidden_states: torch.Tensor):
-        if self.q_lora_rank is not None:
-            ckq = self.q_a_proj(hidden_states)[0]
-            hidden_states_or_q_c = self.q_a_layernorm(ckq)
+        if self.enable_graph_mode:
+            return self.mla_attn.impl.forward(self.mla_attn,
+                                              hidden_states_or_q_c,
+                                              hidden_states, None, kv_cache,
+                                              attn_metadata)
         else:
-            hidden_states_or_q_c = hidden_states
-        kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split(
-            [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
-        kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
-        return self.mla_attn(hidden_states_or_q_c,
-                             kv_c_normed,
-                             k_pe,
-                             output_shape=hidden_states.shape)
+            kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split(
+                [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
+            kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
+            return self.mla_attn(hidden_states_or_q_c,
+                                 kv_c_normed,
+                                 k_pe,
+                                 output_shape=hidden_states.shape)
 
 
 class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
@@ -408,6 +405,54 @@ def __init__(
                 eps=config.rms_norm_eps)
         self.routed_scaling_factor = config.routed_scaling_factor
 
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        residual: Optional[torch.Tensor],
+        kv_cache: Optional[torch.Tensor] = None,
+        attn_metadata: Optional[AttentionMetadata] = None,
+    ) -> torch.Tensor:
+        # Self Attention
+        if residual is None:
+            residual = hidden_states
+            hidden_states = self.input_layernorm(hidden_states)
+        else:
+            hidden_states, residual = self.input_layernorm(
+                hidden_states, residual)
+        hidden_states = self.self_attn(
+            positions=positions,
+            hidden_states=hidden_states,
+            kv_cache=kv_cache,
+            attn_metadata=attn_metadata,
+        )
+
+        if hidden_states.dtype == torch.float16:
+            # Fix FP16 overflow
+            # We scale both hidden_states and residual before
+            # rmsnorm; the rmsnorm result is not affected by the scale.
+            hidden_states *= 1. / self.routed_scaling_factor
+            if self.layer_idx == 0:
+                # The residual is shared by all layers, so we only scale it
+                # on the first layer.
+                residual *= 1. / self.routed_scaling_factor
+
+        # Fully Connected
+        hidden_states, residual = self.post_attention_layernorm(
+            hidden_states, residual)
+        hidden_states = self.mlp(hidden_states)
+
+        if isinstance(self.mlp,
+                      DeepseekV2MLP) and hidden_states.dtype == torch.float16:
+            # Fix FP16 overflow
+            # Scale the DeepseekV2MLP output, which is the input to the
+            # input_layernorm of the next decoder layer.
+            # The scaling of the DeepseekV2MOE output is done in the forward
+            # of DeepseekV2MOE.
+            hidden_states *= 1. / self.routed_scaling_factor
+
+        return hidden_states, residual
+
 
 class CustomDeepseekV2Model(nn.Module):
 
@@ -459,7 +504,9 @@ def forward(
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        intermediate_tensors: Optional[IntermediateTensors],
+        kv_caches: Optional[List[torch.Tensor]] = None,
+        attn_metadata: Optional[AttentionMetadata] = None,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
         if get_pp_group().is_first_rank:
@@ -473,8 +520,13 @@ def forward(
             hidden_states = intermediate_tensors["hidden_states"]
             residual = intermediate_tensors["residual"]
 
-        for layer in self.layers[self.start_layer:self.end_layer]:
-            hidden_states, residual = layer(positions, hidden_states, residual)
+        for i in range(self.start_layer, self.end_layer):
+            layer = self.layers[i]
+            hidden_states, residual = layer(
+                positions, hidden_states, residual,
+                kv_caches[i -
+                          self.start_layer] if kv_caches is not None else None,
+                attn_metadata)
 
         if not get_pp_group().is_last_rank:
             return IntermediateTensors({
@@ -514,6 +566,20 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)
 
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        kv_caches: Optional[List[torch.Tensor]] = None,
+        attn_metadata: Optional[AttentionMetadata] = None,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, IntermediateTensors]:
+        hidden_states = self.model(input_ids, positions, kv_caches,
+                                   attn_metadata, intermediate_tensors,
+                                   inputs_embeds)
+        return hidden_states
+
 
 class CustomDeepseekV3ForCausalLM(CustomDeepseekV2ForCausalLM):
     pass
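The FP16-overflow comments in the new decoder-layer forward rely on RMSNorm being (nearly) scale-invariant, which is why dividing hidden_states and residual by routed_scaling_factor before the norm leaves the normalized output unchanged while keeping the pre-norm additions inside FP16 range. A quick standalone check of that property (shapes, values, and the local rms_norm helper are illustrative, not the project's implementation):

# Numerical check: rms_norm(c * x) == rms_norm(x) up to the epsilon term.
import torch


def rms_norm(x: torch.Tensor, weight: torch.Tensor,
             eps: float = 1e-6) -> torch.Tensor:
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    return x * torch.rsqrt(variance + eps) * weight


x = torch.randn(2, 8)
w = torch.ones(8)
scale = 1.0 / 2.5  # stands in for 1 / routed_scaling_factor

# True; any tiny residual difference comes only from the eps term.
print(torch.allclose(rms_norm(x, w), rms_norm(x * scale, w), atol=1e-3))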
