
Commit 84e2ed8

Performance optimization, usability optimization, and API compatibility adjustments for DeepSeek with NPU graph mode (#731)
### What this PR does / why we need it?
1. Improve inference speed and usability for DeepSeek models with NPU graph mode.
2. Adjust some code to adapt to CANN 8.1.RC1.beta1.
3. Add switches for NPU graph mode and its graph cache.

### Does this PR introduce _any_ user-facing change?
This PR provides an experimental configuration to enable NPU graph mode for DeepSeek models. Users can set `additional_config={'enable_graph_mode': True}` to try this feature (see the sketch below). Note that this feature currently supports only the V0 engine.

### How was this patch tested?
This patch was tested with the newest torch_npu 2.5.1 (https://pypi.org/project/torch-npu/#files) and the CANN 8.1.RC1.beta1 toolkit, nnal, and kernels (https://www.hiascend.com/developer/download/community/result?module=cann) released on 25/30 April.

Signed-off-by: linfeng-yuan <1102311262@qq.com>
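For reviewers who want to try this out, a minimal offline-inference sketch follows. It assumes the installed vLLM exposes `additional_config` on the `LLM` entrypoint and that the V0 engine is selected; the model path is a placeholder.

```python
# Hedged sketch: enabling the experimental NPU graph mode for a DeepSeek model.
import os

os.environ.setdefault("VLLM_USE_V1", "0")  # graph mode currently requires the V0 engine

from vllm import LLM, SamplingParams

llm = LLM(
    model="path/to/deepseek-model",  # placeholder model path
    additional_config={
        "enable_graph_mode": True,       # turn on NPU graph mode (torchair)
        # "use_cached_npu_graph": True,  # optional: reuse a cached compiled graph
    },
)
outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)
```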
Parent: 399b038 · Commit: 84e2ed8

File tree · 6 files changed: +164 −52 lines

- vllm_ascend/attention/attention.py
- vllm_ascend/models/deepseek_v2.py
- vllm_ascend/ops/rotary_embedding.py
- vllm_ascend/platform.py
- vllm_ascend/quantization/w8a8_dynamic.py
- vllm_ascend/worker/model_runner.py

vllm_ascend/attention/attention.py

Lines changed: 14 additions & 6 deletions
```diff
@@ -590,14 +590,14 @@ def build(
             self.input_builder.chunked_prefill_enabled)

         device = self.runner.device
-        use_torchair_graph = graph_pad_size != -1
+        use_npu_graph = graph_pad_size != -1

         max_query_len = max(query_lens)
         max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
         max_decode_seq_len = max(self.curr_seq_lens, default=0)
         num_decode_tokens = self.num_decode_tokens

-        if self.num_prefills == 0 and use_torchair_graph:
+        if self.num_prefills == 0 and use_npu_graph:
             num_seqs = len(seq_lens)
             self.slot_mapping.extend([PAD_SLOT_ID] * graph_pad_size)
             self.block_tables.extend([[]] * graph_pad_size)
@@ -915,7 +915,7 @@ def exec_kv(
         # npu_kv_rmsnorm_rope_cache needs [B, N, S, D]
         kv = kv.view(B, N, S, self.kv_lora_rank + self.qk_rope_head_dim)

-        k_pe, k_nope = torch.ops.npu_inference.npu_kv_rmsnorm_rope_cache(
+        k_pe, k_nope, _, _ = torch.ops.npu_inference.npu_kv_rmsnorm_rope_cache(
             kv,
             self.kv_a_layernorm.weight,
             cos,
@@ -1123,24 +1123,32 @@ def forward(
         elif attn_metadata.decode_metadata:
             assert kv_cache is not None
             if self.enable_graph_mode:
-                # TorchAir's shape is [bs, num_heads_per_rank, seq_len, dim]
+                # shape of query for npu graph mode should be:
+                # [bs, num_heads_per_rank, seq_len, dim]
                 q_nope = q_nope.view(num_tokens, self.num_heads, 1, -1)
                 q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1)
+                # shape of knope/k_pe for npu graph mode should be:
+                # [num_blocks, num_kv_heads, block_size, self.kv_lora_rank/self.qk_rope_head_dim]
+                block_size = kv_cache[0].shape[1]
+                k_nope = k_nope.view(-1, self.num_kv_heads, block_size,
+                                     self.kv_lora_rank)
+                k_pe = k_pe.view(-1, self.num_kv_heads, block_size,
+                                 self.qk_rope_head_dim)
                 attn_output, _ = torch.ops.npu.npu_fused_infer_attention_score(
                     q_nope,
                     k_nope,
                     k_nope,
                     query_rope=q_pe,
                     key_rope=k_pe,
                     num_heads=self.num_heads,
-                    num_key_value_heads=1,
+                    num_key_value_heads=self.num_kv_heads,
                     input_layout="BNSD",
                     atten_mask=attn_metadata.attn_mask,
                     scale=self.scale,
                     antiquant_mode=0,
                     antiquant_scale=None,
                     block_table=attn_metadata.block_tables,
-                    block_size=kv_cache[0].shape[1],
+                    block_size=block_size,
                     actual_seq_lengths_kv=attn_metadata.seq_lens,
                 )
                 attn_output = attn_output.view(num_tokens, -1,
```
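To make the layout change easier to follow, here is a hedged, pure-PyTorch shape sketch of the reshapes above; the dimensions are illustrative placeholders rather than values from any real checkpoint.

```python
import torch

num_tokens, num_heads, num_kv_heads = 8, 16, 1
kv_lora_rank, qk_rope_head_dim = 512, 64
num_blocks, block_size = 4, 128

q_nope = torch.randn(num_tokens, num_heads * kv_lora_rank)
q_pe = torch.randn(num_tokens, num_heads * qk_rope_head_dim)
k_nope = torch.randn(num_blocks * block_size, num_kv_heads * kv_lora_rank)
k_pe = torch.randn(num_blocks * block_size, num_kv_heads * qk_rope_head_dim)

# Queries become [bs, num_heads_per_rank, seq_len=1, dim] for the BNSD layout.
q_nope = q_nope.view(num_tokens, num_heads, 1, -1)
q_pe = q_pe.view(num_tokens, num_heads, 1, -1)

# Keys are viewed against the paged KV-cache layout:
# [num_blocks, num_kv_heads, block_size, kv_lora_rank / qk_rope_head_dim].
k_nope = k_nope.view(-1, num_kv_heads, block_size, kv_lora_rank)
k_pe = k_pe.view(-1, num_kv_heads, block_size, qk_rope_head_dim)

print(q_nope.shape, q_pe.shape, k_nope.shape, k_pe.shape)
# torch.Size([8, 16, 1, 512]) torch.Size([8, 16, 1, 64])
# torch.Size([4, 1, 128, 512]) torch.Size([4, 1, 128, 64])
```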

vllm_ascend/models/deepseek_v2.py

Lines changed: 83 additions & 9 deletions
```diff
@@ -30,6 +30,7 @@

 import torch
 import torch.distributed as dist
+import torch_npu
 from torch import nn
 from transformers import PretrainedConfig
 from vllm.attention import Attention, AttentionMetadata
@@ -39,10 +40,13 @@
     get_tensor_model_parallel_world_size,
     get_tp_group, tensor_model_parallel_all_reduce)
 from vllm.forward_context import get_forward_context
+from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
+                                               MergedColumnParallelLinear,
                                                ReplicatedLinear,
-                                               RowParallelLinear)
+                                               RowParallelLinear,
+                                               UnquantizedLinearMethod)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
@@ -55,15 +59,84 @@
     yarn_get_mscale  # ruff: noqa: E501
 from vllm.model_executor.models.deepseek_v2 import (DeepseekV2Attention,
                                                     DeepseekV2DecoderLayer,
-                                                    DeepseekV2MLAAttention,
-                                                    DeepseekV2MLP)
+                                                    DeepseekV2MLAAttention)
 from vllm.model_executor.models.utils import (
     PPMissingLayer, make_empty_intermediate_tensors_factory, make_layers,
     maybe_prefix)
-# >>>>>>> dcd5c73 (Feat: Graph mode for deepseek v2/v3.)
 from vllm.sequence import IntermediateTensors

 from vllm_ascend.ops.fused_moe import AscendFusedMoE
+from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod
+
+
+class CustomDeepseekV2MLP(nn.Module):
+
+    def __init__(
+        self,
+        hidden_size: int,
+        intermediate_size: int,
+        hidden_act: str,
+        quant_config: Optional[QuantizationConfig] = None,
+        reduce_results: bool = True,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.gate_up_proj = MergedColumnParallelLinear(
+            hidden_size, [intermediate_size] * 2,
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.gate_up_proj")
+        self.down_proj = RowParallelLinear(intermediate_size,
+                                           hidden_size,
+                                           bias=False,
+                                           quant_config=quant_config,
+                                           reduce_results=reduce_results,
+                                           prefix=f"{prefix}.down_proj")
+        if hidden_act != "silu":
+            raise ValueError(f"Unsupported activation: {hidden_act}. "
+                             "Only silu is supported for now.")
+        self.act_fn = SiluAndMul()
+
+        # NOTE: `torch_npu.npu_dequant_swiglu_quant` can only be enabled in dynamic quant
+        self.is_dynamic_quant = not isinstance(
+            self.gate_up_proj.quant_method,
+            UnquantizedLinearMethod) and isinstance(
+                self.gate_up_proj.quant_method.quant_method,
+                AscendW8A8DynamicLinearMethod)
+
+    def forward(self, x):
+        if self.is_dynamic_quant:
+            x, dynamic_scale = torch_npu.npu_dynamic_quant(x)
+            x = torch_npu.npu_quant_matmul(
+                x,
+                self.gate_up_proj.weight,
+                self.gate_up_proj.weight_scale,
+                output_dtype=torch.int32,
+            )
+            x, dynamic_scale = torch_npu.npu_dequant_swiglu_quant(
+                x=x,
+                weight_scale=self.gate_up_proj.weight_scale_fp32,
+                activation_scale=dynamic_scale,
+                bias=None,
+                quant_scale=None,
+                quant_offset=None,
+                group_index=None,
+                activate_left=True,
+                quant_mode=1)
+            x = torch_npu.npu_quant_matmul(
+                x,
+                self.down_proj.weight,
+                self.down_proj.weight_scale,
+                pertoken_scale=dynamic_scale,
+                output_dtype=torch.bfloat16,
+            )
+            if self.down_proj.reduce_results and self.down_proj.tp_size > 1:
+                x = tensor_model_parallel_all_reduce(x)
+            return x
+        gate_up, _ = self.gate_up_proj(x)
+        x = self.act_fn(gate_up)
+        x, _ = self.down_proj(x)
+        return x


 class CustomDeepseekV2MoE(nn.Module):
@@ -119,7 +192,7 @@ def __init__(
         if config.n_shared_experts is not None:
             intermediate_size = (config.moe_intermediate_size *
                                  config.n_shared_experts)
-            self.shared_experts = DeepseekV2MLP(
+            self.shared_experts = CustomDeepseekV2MLP(
                 hidden_size=config.hidden_size,
                 intermediate_size=intermediate_size,
                 hidden_act=config.hidden_act,
@@ -392,7 +465,7 @@ def __init__(
                 prefix=f"{prefix}.mlp",
             )
         else:
-            self.mlp = DeepseekV2MLP(
+            self.mlp = CustomDeepseekV2MLP(
                 hidden_size=config.hidden_size,
                 intermediate_size=config.intermediate_size,
                 hidden_act=config.hidden_act,
@@ -442,8 +515,9 @@ def forward(
                                                    hidden_states, residual)
         hidden_states = self.mlp(hidden_states)

-        if isinstance(self.mlp,
-                      DeepseekV2MLP) and hidden_states.dtype == torch.float16:
+        if isinstance(
+                self.mlp,
+                CustomDeepseekV2MLP) and hidden_states.dtype == torch.float16:
             # Fix FP16 overflow
             # Scaling the DeepseekV2MLP output, it is the input of
             # input_layernorm of next decoder layer.
@@ -582,4 +656,4 @@ def forward(


 class CustomDeepseekV3ForCausalLM(CustomDeepseekV2ForCausalLM):
-    pass
+    pass
```
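The dynamic-quant branch of `CustomDeepseekV2MLP.forward` chains three fused NPU ops. As a reading aid only, here is a hedged pure-PyTorch reference of the arithmetic they stand in for; the symmetric int8 scheme, the float accumulation, and all shapes are illustrative assumptions, not the kernels' exact behavior.

```python
import torch
import torch.nn.functional as F

def per_token_quant(x: torch.Tensor):
    # symmetric per-row int8-style quantization: scale = max(|row|) / 127
    scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0
    return torch.clamp(torch.round(x / scale), -128, 127), scale

hidden, inter = 64, 128
x = torch.randn(4, hidden)
w_gate_up = torch.randn(2 * inter, hidden)   # fused gate/up projection weight
w_down = torch.randn(hidden, inter)

# 1) ~ npu_dynamic_quant + npu_quant_matmul: quantize activations per token,
#    multiply with the quantized weight (accumulated in float here for portability)
x_q, x_scale = per_token_quant(x)
w_q, w_scale = per_token_quant(w_gate_up)    # stand-in for per-channel weight scales
acc = x_q @ w_q.t()

# 2) ~ npu_dequant_swiglu_quant: dequantize, apply SwiGLU (silu(gate) * up),
#    then re-quantize per token for the next matmul
y = acc * x_scale * w_scale.t()
gate, up = y.chunk(2, dim=-1)
act_q, act_scale = per_token_quant(F.silu(gate) * up)

# 3) ~ npu_quant_matmul with pertoken_scale: quantized down projection,
#    dequantized with the per-token activation scale
wd_q, wd_scale = per_token_quant(w_down)
out = (act_q @ wd_q.t()) * act_scale * wd_scale.t()
print(out.shape)  # torch.Size([4, 64])
```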

vllm_ascend/ops/rotary_embedding.py

Lines changed: 6 additions & 0 deletions
```diff
@@ -221,10 +221,16 @@ def _set_cos_sin_cache(self, seq_len, device, dtype):
         t = torch.arange(seq_len, device=device, dtype=torch.float32)

         freqs = torch.outer(t, inv_freq)
+        cos_cached = torch.cat([freqs, freqs], dim=-1).cos() * self.mscale
+        sin_cached = torch.cat([freqs, freqs], dim=-1).sin() * self.mscale
+        cos_cached = cos_cached.to(dtype)
+        sin_cached = sin_cached.to(dtype)
         cache = torch.cat([freqs.cos() * self.mscale,
                            freqs.sin() * self.mscale],
                           dim=-1).to(dtype)
         self.register_buffer("cos_sin_cache", cache, persistent=False)
+        self.register_buffer("cos_cached", cos_cached, persistent=False)
+        self.register_buffer("sin_cached", sin_cached, persistent=False)


 def deepseek_rope_init_func(
```
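As a hedged illustration of the two buffer layouts (with `mscale` taken as 1.0 and a toy rotary dim), the existing `cos_sin_cache` packs the cos and sin halves side by side over the half-dim frequencies, while the new `cos_cached`/`sin_cached` duplicate the frequencies so each buffer spans the full rotary dim, presumably for consumption by the graph-mode path:

```python
import torch

seq_len, rotary_dim, mscale = 4, 8, 1.0
inv_freq = 1.0 / (10000 ** (torch.arange(0, rotary_dim, 2).float() / rotary_dim))
t = torch.arange(seq_len, dtype=torch.float32)
freqs = torch.outer(t, inv_freq)                                    # [seq_len, rotary_dim // 2]

cos_sin_cache = torch.cat([freqs.cos() * mscale,
                           freqs.sin() * mscale], dim=-1)           # [seq_len, rotary_dim]
cos_cached = torch.cat([freqs, freqs], dim=-1).cos() * mscale       # [seq_len, rotary_dim]
sin_cached = torch.cat([freqs, freqs], dim=-1).sin() * mscale       # [seq_len, rotary_dim]

print(cos_sin_cache.shape, cos_cached.shape, sin_cached.shape)
```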

vllm_ascend/platform.py

Lines changed: 9 additions & 1 deletion
```diff
@@ -124,7 +124,10 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
             enforce_eager = True
             logger.warning(
                 "NPU compilation support pending. Will be available in future CANN and "
-                "torch_npu releases. Using default: enforce_eager=True")
+                "torch_npu releases. NPU graph mode is currently experimental and disabled "
+                "by default. You can just adopt additional_config={'enable_graph_mode': True} "
+                "to serve deepseek models with NPU graph mode on vllm-ascend with V0 engine. "
+            )

         if enforce_eager or compilation_config.level == CompilationLevel.NO_COMPILATION:
             logger.info("Compilation disabled, using eager mode by default")
@@ -150,6 +153,11 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
                     "enable_graph_mode is not supported because the version of torch is too low, forcing close enable_graph_mode"
                 )
                 vllm_config.additional_config["enable_graph_mode"] = False
+            if enable_graph_mode and envs.VLLM_USE_V1:
+                logger.warning(
+                    "NPU graph mode is still experimental and not supported for V1 currently, "
+                    "it has been disabled automatically.")
+                vllm_config.additional_config["enable_graph_mode"] = False

         parallel_config = vllm_config.parallel_config
         if parallel_config and parallel_config.worker_cls == "auto":
```

vllm_ascend/quantization/w8a8_dynamic.py

Lines changed: 33 additions & 30 deletions
```diff
@@ -62,38 +62,38 @@ def apply_mlp(x: torch.Tensor,
     h = x
     pertoken_scale = dynamic_scale

-    output_dtype = torch.bfloat16 if w1_scale.dtype == torch.bfloat16 else \
-        torch.float16
-
     # gmm1: gate_up_proj
-    gate_up_out_list = torch_npu.npu_grouped_matmul(
-        x=[h],
-        weight=[w1],
-        scale=[w1_scale],
-        per_token_scale=[pertoken_scale],
-        split_item=3,
-        group_list_type=group_list_type,
-        group_type=0,
-        group_list=group_list,
-        output_dtype=output_dtype)
-    gate_up_out = gate_up_out_list[0]
-
-    # swiglu
-    swiglu_out = torch_npu.npu_swiglu(gate_up_out)
-    swiglu_out, swiglu_out_scale = torch_npu.npu_dynamic_quant(swiglu_out)
+    gate_up_out = torch_npu.npu_grouped_matmul(x=[h],
+                                               weight=[w1],
+                                               split_item=3,
+                                               group_list_type=group_list_type,
+                                               group_type=0,
+                                               group_list=group_list,
+                                               output_dtype=torch.int32)[0]
+
+    swiglu_out, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant(
+        x=gate_up_out,
+        weight_scale=w1_scale,
+        activation_scale=pertoken_scale,
+        bias=None,
+        quant_scale=None,
+        quant_offset=None,
+        group_index=group_list,
+        activate_left=True,
+        quant_mode=1,
+    )

     # down_proj
-    down_out_list = torch_npu.npu_grouped_matmul(
-        x=[swiglu_out],
-        weight=[w2],
-        scale=[w2_scale],
-        per_token_scale=[swiglu_out_scale],
-        split_item=3,
-        group_list_type=group_list_type,
-        group_type=0,
-        group_list=group_list,
-        output_dtype=output_dtype)
-    return down_out_list[0]
+    down_out = torch_npu.npu_grouped_matmul(x=[swiglu_out],
+                                            weight=[w2],
+                                            scale=[w2_scale],
+                                            per_token_scale=[swiglu_out_scale],
+                                            split_item=2,
+                                            group_list_type=group_list_type,
+                                            group_type=0,
+                                            group_list=group_list,
+                                            output_dtype=w2_scale.dtype)[0]
+    return down_out


 def fused_experts_with_mc2(
@@ -363,7 +363,10 @@ def apply(
     def process_weights_after_loading(self, layer):
         if self.transpose_weight:
             layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()
+        # cast quantized weight tensors in NZ format (29) for higher inference speed
+        layer.weight.data = torch_npu.npu_format_cast(layer.weight.data, 29)
         layer.weight_scale.data = layer.weight_scale.data.flatten()
+        layer.weight_scale_fp32 = layer.weight_scale.data.to(torch.float32)
         layer.weight_offset.data = layer.weight_offset.data.flatten()


@@ -508,7 +511,7 @@ def process_weights_after_loading(self, layer):
         layer.w2_weight.data = layer.w2_weight.data.transpose(
             1, 2).contiguous()
         layer.w13_weight_scale.data = layer.w13_weight_scale.data.view(
-            layer.w13_weight_scale.data.shape[0], -1)
+            layer.w13_weight_scale.data.shape[0], -1).to(torch.float32)
         layer.w13_weight_offset.data = layer.w13_weight_offset.data.view(
             layer.w13_weight_offset.data.shape[0], -1)
         layer.w2_weight_scale.data = layer.w2_weight_scale.data.view(
```
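Since `apply_mlp` leans on grouped matmuls routed by `group_list`, here is a hedged sketch of that routing semantics in plain PyTorch; the cumulative-count convention and all shapes are illustrative assumptions rather than the exact `npu_grouped_matmul` contract.

```python
# Tokens are laid out contiguously per expert, and group_list gives the cumulative
# token count per expert, so each segment is multiplied by its own expert weight.
import torch

num_experts, hidden, inter = 4, 32, 64
tokens_per_expert = torch.tensor([3, 0, 5, 2])
group_list = torch.cumsum(tokens_per_expert, dim=0)      # e.g. [3, 3, 8, 10]

x = torch.randn(int(tokens_per_expert.sum()), hidden)    # expert-sorted tokens
w1 = torch.randn(num_experts, hidden, inter)             # per-expert weights

def grouped_matmul(x, w, group_list):
    out, start = [], 0
    for expert_id, end in enumerate(group_list.tolist()):
        if end > start:                                  # skip empty experts
            out.append(x[start:end] @ w[expert_id])
        start = end
    return torch.cat(out, dim=0)

print(grouped_matmul(x, w1, group_list).shape)           # torch.Size([10, 64])
```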

vllm_ascend/worker/model_runner.py

Lines changed: 19 additions & 6 deletions
```diff
@@ -69,6 +69,8 @@

 TModelInputForNPU = TypeVar('TModelInputForNPU', bound="ModelInputForNPU")
 ENCODER_NUM = 0
+# if True, allow tensor initialization and casting with internal format (e.g., NZ)
+torch.npu.config.allow_internal_format = True


 @dataclass(frozen=True)
@@ -864,10 +866,13 @@ def __init__(
             self.vllm_config.compilation_config.max_capture_size

         self.enable_graph_mode = False
+        self.use_cached_npu_graph = False
         additional_config = vllm_config.additional_config
         if additional_config:
             self.enable_graph_mode = additional_config.get(
                 "enable_graph_mode", False)
+            self.use_cached_npu_graph = additional_config.get(
+                "use_cached_npu_graph", False)

         self.has_inner_state = model_config.has_inner_state

@@ -981,12 +986,20 @@ def load_model(self) -> None:
             config.experimental_config.frozen_parameter = True
             config.experimental_config.tiling_schedule_optimize = True
             torch.npu.set_compile_mode(jit_compile=False)
-            self.compile_model = torchair.inference.cache_compile(
-                self.model.forward,
-                dynamic=True,
-                fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
-                config=config,
-                ge_cache=False)
+            if not self.use_cached_npu_graph:
+                npu_backend = torchair.get_npu_backend(compiler_config=config)
+                self.compile_model = torch.compile(
+                    self.model,
+                    dynamic=True,
+                    fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
+                    backend=npu_backend)
+            else:
+                self.compile_model = torchair.inference.cache_compile(
+                    self.model.forward,
+                    dynamic=True,
+                    fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
+                    config=config,
+                    ge_cache=False)

     def save_sharded_state(
         self,
```