
Commit 0fc9b56

[Perf] Improve MLA multistream performance (#1353)
### What this PR does / why we need it?

> Need to merge after PR #1322

According to benchmark results, this PR brings approximately 1% performance gain.

#### Before Improvement

Profiling

<img width="1147" alt="Screenshot 2025-06-22 14 54 47" src="https://github.com/user-attachments/assets/4a4dc7f1-5b76-45d5-864d-dd7f8faf993c" />

Evaluation

```
# server launch command
python -m vllm.entrypoints.openai.api_server --model=/DeepSeek-R1-W8A8 \
    --quantization ascend \
    --served-model-name auto \
    --trust-remote-code \
    --distributed-executor-backend=mp \
    --port 8006 \
    -tp=16 \
    --max-num-seqs 24 \
    --max-model-len 32768 \
    --max-num-batched-tokens 8192 \
    --block-size 128 \
    --no-enable-prefix-caching \
    --additional-config '{"torchair_graph_config":{"enable_multistream_mla": true,"enabled":true,"use_cached_graph":true,"graph_batch_sizes":[24]},"ascend_scheduler_config":{"enabled":true},"expert_tensor_parallel_size":16}' \
    --gpu-memory-utilization 0.96

# client benchmark command
python /root/vllm/benchmarks/benchmark_serving.py --backend vllm --dataset-name random \
    --random-input-len 4096 \
    --random-output-len 1536 \
    --num-prompts 200 \
    --ignore-eos \
    --model auto \
    --tokenizer /DeepSeek-R1-W8A8 \
    --port 8006 \
    --request-rate 1 \
    --max-concurrency 24 \
    --save-result \
    --skip-initial-test \
    --metric-percentiles "50,90,99"
```

```
============ Serving Benchmark Result ============
Successful requests:                     200
Benchmark duration (s):                  958.59
Total input tokens:                      819200
Total generated tokens:                  307200
Request throughput (req/s):              0.2086
Output token throughput (tok/s):         320.47
Total Token throughput (tok/s):          1175.05
---------------Time to First Token----------------
Mean TTFT (ms):                          942.70
Median TTFT (ms):                        713.87
P50 TTFT (ms):                           713.87
P90 TTFT (ms):                           1363.88
P99 TTFT (ms):                           2008.73
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          68.96
Median TPOT (ms):                        69.49
P50 TPOT (ms):                           69.49
P90 TPOT (ms):                           70.42
P99 TPOT (ms):                           70.72
---------------Inter-token Latency----------------
Mean ITL (ms):                           68.96
Median ITL (ms):                         59.88
P50 ITL (ms):                            59.88
P90 ITL (ms):                            61.59
P99 ITL (ms):                            68.82
==================================================
```

#### After Improvement

Profiling

<img width="1200" alt="Screenshot 2025-06-22 14 55 42" src="https://github.com/user-attachments/assets/e3eb9dec-0ff0-4e5f-ab94-93c65003e51f" />

Evaluation

```
============ Serving Benchmark Result ============
Successful requests:                     200
Benchmark duration (s):                  948.08
Total input tokens:                      819200
Total generated tokens:                  307200
Request throughput (req/s):              0.2110
Output token throughput (tok/s):         324.02
Total Token throughput (tok/s):          1188.08
---------------Time to First Token----------------
Mean TTFT (ms):                          1019.25
Median TTFT (ms):                        714.63
P50 TTFT (ms):                           714.63
P90 TTFT (ms):                           1367.31
P99 TTFT (ms):                           2661.52
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          68.14
Median TPOT (ms):                        68.68
P50 TPOT (ms):                           68.68
P90 TPOT (ms):                           69.33
P99 TPOT (ms):                           70.30
---------------Inter-token Latency----------------
Mean ITL (ms):                           68.14
Median ITL (ms):                         59.04
P50 ITL (ms):                            59.04
P90 ITL (ms):                            60.93
P99 ITL (ms):                            66.89
==================================================
```

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

- vLLM version: v0.9.2
- vLLM main: vllm-project/vllm@65393ee

Signed-off-by: ApsarasX <apsarax@outlook.com>
1 parent cc210f4 commit 0fc9b56

File tree

3 files changed: +58 -30 lines

- vllm_ascend/attention/mla_v1.py
- vllm_ascend/models/deepseek_v2.py
- vllm_ascend/utils.py

vllm_ascend/attention/mla_v1.py

Lines changed: 38 additions & 23 deletions
```diff
@@ -21,7 +21,7 @@
 from vllm_ascend.multistream.context import get_multistream_comm_context
 from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
 from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla
-from vllm_ascend.utils import npu_stream_switch, npu_wait_tensor
+from vllm_ascend.utils import npu_prefetch, npu_stream_switch, npu_wait_tensor
 from vllm_ascend.worker.npu_input_batch import InputBatch
 
 if TYPE_CHECKING:
@@ -579,13 +579,18 @@ def __init__(
             " please make sure after the tensor parallel split, num_heads / num_kv_heads in "
             "{32, 64, 128}.")
 
-    def _v_up_proj_and_o_proj(self, x):
+    def _v_up_proj_and_o_proj(self, x, enable_multistream_mla: bool = False):
         # Convert from (B, N, L) to (N, B, L)
         x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
         # Multiply (N, B, L) x (N, L, V) -> (N, B, V)
         x = torch.bmm(x, self.W_UV)
         # Convert from (N, B, V) to (B, N * V)
         x = x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim)
+        MAX_O_PROJ_PREFETCH_SIZE = 16 * 1024 * 1024  # 16MB
+        npu_prefetch(self.o_proj.weight,
+                     x,
+                     max_size=MAX_O_PROJ_PREFETCH_SIZE,
+                     enabled=enable_multistream_mla)
         return self.o_proj(x, is_prefill=False)[0]
 
     # Return `ql_nope`, `q_pe`
@@ -864,7 +869,6 @@ def exec_kv(
         sin: torch.Tensor,
         kv_cache: Tuple,
         slots: torch.Tensor,
-        enable_multistream_mla: bool = False,
     ):
 
         B = hidden_states.shape[0]
@@ -874,21 +878,18 @@ def exec_kv(
         # npu_kv_rmsnorm_rope_cache needs [B, N, S, D]
         kv = kv.view(B, N, S, self.kv_lora_rank + self.qk_rope_head_dim)
         cache_mode = "PA_NZ" if self.enable_kv_nz else "PA"
-        with npu_stream_switch("mla_secondary",
-                               0,
-                               enabled=enable_multistream_mla):
-            k_pe, k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache(
-                kv,
-                self.kv_a_layernorm.weight,
-                cos,
-                sin,
-                slots.to(torch.int64),
-                kv_cache[1],
-                kv_cache[0],
-                epsilon=self.kv_a_layernorm.variance_epsilon,
-                cache_mode=cache_mode,
-            )
-        return k_pe, k_nope
+        k_pe, k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache(
+            kv,
+            self.kv_a_layernorm.weight,
+            cos,
+            sin,
+            slots.to(torch.int64),
+            kv_cache[1],
+            kv_cache[0],
+            epsilon=self.kv_a_layernorm.variance_epsilon,
+            cache_mode=cache_mode,
+        )
+        return k_pe, k_nope, kv
 
     def exec_kv_prefill(
         self,
@@ -940,6 +941,7 @@ def _forward_decode(
         k_pe: torch.Tensor,
         kv_c_and_k_pe_cache: torch.Tensor,
         attn_metadata: AscendMLAMetadata,
+        enable_multistream_mla: bool = False,
     ) -> torch.Tensor:
         decode_meta = attn_metadata.decode
         assert decode_meta is not None
@@ -1020,7 +1022,8 @@
                 out=attn_output)
         current_ms_metadata = get_multistream_comm_context()
         if current_ms_metadata is None:
-            return self._v_up_proj_and_o_proj(attn_output)
+            return self._v_up_proj_and_o_proj(attn_output,
+                                              enable_multistream_mla)
         else:
             current_ms_metadata.before_comm_event.record()
             with torch.npu.stream(current_ms_metadata.comm_stream):
@@ -1037,6 +1040,7 @@ def forward(
         attn_metadata: M,
         output: Optional[torch.Tensor] = None,
         enable_multistream_mla: bool = False,
+        ckq: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         assert output is not None, "Output tensor must be provided."
         if attn_metadata is None:
@@ -1091,6 +1095,15 @@ def forward(
                 sin = sin[attn_metadata.decode.input_positions]
                 cos = cos[:, None, None, :]
                 sin = sin[:, None, None, :]
+                with npu_stream_switch("mla_secondary",
+                                       0,
+                                       enabled=enable_multistream_mla):
+                    npu_wait_tensor(hidden_states_or_kv_c_normed,
+                                    ckq,
+                                    enabled=enable_multistream_mla)
+                    decode_k_pe, decode_k_nope, decode_kv = self.exec_kv(
+                        hidden_states_or_kv_c_normed, cos, sin, kv_cache,
+                        attn_metadata.slot_mapping)
                 # Without explicitly controlling the order, IndexByTensor operations
                 # would be placed after `matmul W_KV_T` hindering the overlapping of
                 # KvRmsNormRopeCache and SingleRope.
@@ -1100,12 +1113,13 @@ def forward(
                 npu_wait_tensor(decode_hs_or_q_c,
                                 sin,
                                 enabled=enable_multistream_mla)
+                npu_wait_tensor(decode_hs_or_q_c,
+                                decode_kv,
+                                enabled=enable_multistream_mla)
+
             decode_ql_nope, decode_q_pe = \
                 self._q_proj_and_k_up_proj(decode_hs_or_q_c)
             if self.running_in_graph:
-                decode_k_pe, decode_k_nope = self.exec_kv(
-                    hidden_states_or_kv_c_normed, cos, sin, kv_cache,
-                    attn_metadata.slot_mapping, enable_multistream_mla)
                 with npu_stream_switch("mla_secondary",
                                        0,
                                        enabled=enable_multistream_mla):
@@ -1194,7 +1208,8 @@ def forward(
         if self.running_in_graph:
             return self._forward_decode(decode_ql_nope, decode_q_pe,
                                         decode_k_nope, decode_k_pe,
-                                        kv_cache, attn_metadata)
+                                        kv_cache, attn_metadata,
+                                        enable_multistream_mla)
         else:
             output_decode = self._forward_decode(decode_ql_nope,
                                                  decode_q_pe,
```
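For readers unfamiliar with the multistream idiom used in these hunks, the sketch below shows how `npu_stream_switch` and `npu_wait_tensor` from `vllm_ascend.utils` combine to overlap a side computation with work on the default stream. It is a minimal, hypothetical example rather than code from this PR: `overlap_side_branch`, `producer_out`, `side_input`, and the arithmetic stand-ins are invented names, and it assumes an Ascend NPU environment where these helpers and `torch_npu` are available.

```python
import torch

from vllm_ascend.utils import npu_stream_switch, npu_wait_tensor


def overlap_side_branch(producer_out: torch.Tensor,
                        side_input: torch.Tensor,
                        enabled: bool = True) -> torch.Tensor:
    # Launch the side branch on the "mla_secondary" stream so it can overlap
    # with whatever the default stream executes next.
    with npu_stream_switch("mla_secondary", 0, enabled=enabled):
        # Make the secondary stream wait until `producer_out` (computed on the
        # default stream) is ready before consuming it, mirroring the
        # npu_wait_tensor(hidden_states_or_kv_c_normed, ckq, ...) call above.
        npu_wait_tensor(side_input, producer_out, enabled=enabled)
        side_out = side_input + producer_out  # stand-in for exec_kv()

    # Back on the default stream: this kernel has no dependency on the side
    # branch, so it can run concurrently with it.
    main_out = producer_out @ producer_out.T  # stand-in for the q projection

    # Re-order the streams before the default stream consumes the side-branch
    # result, mirroring npu_wait_tensor(decode_hs_or_q_c, decode_kv, ...).
    npu_wait_tensor(main_out, side_out, enabled=enabled)
    return main_out + side_out.sum()
```

As in the patched code, passing `enabled=False` turns both helpers into no-ops, so the single-stream path is left unchanged when multistream MLA is disabled.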

vllm_ascend/models/deepseek_v2.py

Lines changed: 6 additions & 7 deletions
```diff
@@ -74,8 +74,7 @@
 from vllm_ascend.ops.fused_moe import AscendFusedMoE
 from vllm_ascend.quantization.quant_config import AscendLinearMethod
 from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod
-from vllm_ascend.utils import (dispose_tensor, npu_stream_switch,
-                               npu_wait_tensor)
+from vllm_ascend.utils import dispose_tensor, npu_prefetch
 
 
 class CustomDeepseekV2SiluAndMul(SiluAndMul):
@@ -567,12 +566,12 @@ def forward(
                                  and attn_metadata.num_decodes > 0)
         forward_kwargs = {"enable_multistream_mla": enable_multistream_mla}
         if self.q_lora_rank is not None:
+            npu_prefetch(self.q_a_proj.weight,
+                         hidden_states,
+                         enabled=enable_multistream_mla)
             ckq = self.q_a_proj(hidden_states)[0]
-            npu_wait_tensor(hidden_states, ckq, enabled=enable_multistream_mla)
-            with npu_stream_switch("mla_secondary",
-                                   0,
-                                   enabled=enable_multistream_mla):
-                hidden_states_or_q_c = self.q_a_layernorm(ckq)
+            hidden_states_or_q_c = self.q_a_layernorm(ckq)
+            forward_kwargs['ckq'] = ckq
         else:
             hidden_states_or_q_c = hidden_states
         if self.torchair_graph_enabled:
```
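To make the connection between the two files explicit, the snippet below is a hypothetical, CPU-only analogy (not repository code) of the hand-off this hunk sets up: the `q_a_proj` output `ckq` is stored in `forward_kwargs` so that the attention implementation, which gained a matching `ckq` parameter in `mla_v1.py`, can use it purely as a synchronization dependency before `exec_kv`. `attn_forward` and the module shapes are placeholders.

```python
from typing import Optional

import torch

# Stand-ins for the model's q_a_proj / q_a_layernorm submodules.
q_a_proj = torch.nn.Linear(16, 8, bias=False)
q_a_layernorm = torch.nn.LayerNorm(8)
hidden_states = torch.randn(4, 16)


def attn_forward(q_c: torch.Tensor,
                 *,
                 enable_multistream_mla: bool = False,
                 ckq: Optional[torch.Tensor] = None) -> torch.Tensor:
    # Placeholder for AscendMLAImpl.forward: when `ckq` is supplied, the
    # decode path waits on it from the secondary stream (see the
    # npu_wait_tensor(..., ckq, ...) call in the mla_v1.py diff) before
    # writing the KV cache.
    return q_c


forward_kwargs = {"enable_multistream_mla": True}
ckq = q_a_proj(hidden_states)              # like self.q_a_proj(hidden_states)[0]
hidden_states_or_q_c = q_a_layernorm(ckq)  # like self.q_a_layernorm(ckq)
forward_kwargs["ckq"] = ckq                # hand the dependency token down
out = attn_forward(hidden_states_or_q_c, **forward_kwargs)
```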

vllm_ascend/utils.py

Lines changed: 14 additions & 0 deletions
```diff
@@ -416,6 +416,20 @@ def npu_wait_tensor(self: torch.Tensor,
     return _npu_wait_tensor(self, dependency) if enabled else self
 
 
+# TODO(wxy): Move to ops module
+def npu_prefetch(input: torch.Tensor,
+                 dependency: torch.Tensor,
+                 max_size: int = 0,
+                 *,
+                 enabled: bool = True):
+    if not enabled:
+        return
+    input_size = input.element_size() * input.numel()
+    if max_size <= 0 or max_size > input_size:
+        max_size = input_size
+    torch_npu.npu_prefetch(input, dependency, max_size)
+
+
 # TODO(zzzzwwjj): move this into forward_context
 class FusedMoEState(Enum):
     AllGather = 0
```
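A short usage note on the new helper: `max_size` is a byte budget, and the clamp above means the default `max_size=0` prefetches the whole tensor, while a positive value caps how much is prefetched, which is how the 16 MB limit applied to `o_proj.weight` in `mla_v1.py` behaves. The sketch below is a minimal illustration, assuming an Ascend device where `torch_npu` is installed and tensors live on `npu`; the shapes are placeholders.

```python
import torch

from vllm_ascend.utils import npu_prefetch

# Placeholder tensors standing in for a projection weight and the activation
# whose computation should overlap with the weight prefetch.
weight = torch.empty(1024, 4096, dtype=torch.bfloat16, device="npu")
activations = torch.empty(24, 1024, dtype=torch.bfloat16, device="npu")

# Default max_size=0: clamp to the full weight size and prefetch all of it.
npu_prefetch(weight, activations)

# Cap the prefetch at 16 MB, and skip it entirely unless the multistream
# MLA path is active.
npu_prefetch(weight, activations,
             max_size=16 * 1024 * 1024,
             enabled=True)
```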
