Skip to content

Commit fd0f6fa

Browse files
committed
Offload vector operations of MLA to another stream
With the expected overlaping being: ``` | q_rmsnorm | | kv_norm_rope_cache | | q_rope | | matmul W_DQ | matmul W_DKV | matmul W_UQ | split | matmul W_KV_T | ``` Controlled by `torchair_graph_config.enable_multistream_mla`, defaulted to False. Signed-off-by: sdmyzlp <lrwei2@petalmail.com>
1 parent 7bdc606 commit fd0f6fa

File tree

5 files changed

+56
-19
lines changed

5 files changed

+56
-19
lines changed

docs/source/user_guide/additional_config.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ The details of each config option are as follows:
3939
| Name | Type | Default | Description |
4040
| ---- | ---- | ------- | ----------- |
4141
| `enabled` | bool | `False` | Whether to enable torchair graph mode |
42+
| `enable_multistream_mla`| bool | `False` | Whether to put vector ops of MLA to another stream |
4243
| `enable_multistream_moe`| bool | `False` | Whether to enable multistream shared expert |
4344
| `enable_view_optimize` | bool | `True` | Whether to enable torchair view optimization |
4445
| `use_cached_graph` | bool | `False` | Whether to use cached graph |

tests/singlecard/test_ascend_config.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ def test_run_with_ascend_config():
5959
"graph_batch_sizes": [1, 2, 4, 8],
6060
"graph_batch_sizes_init": False,
6161
"enable_multistream_moe": True,
62+
"enable_multistream_mla": True,
6263
},
6364
"ascend_scheduler_config": {
6465
"enabled": True,
@@ -79,6 +80,7 @@ def test_run_with_ascend_config():
7980
1, 2, 4, 8
8081
]
8182
assert not ascend_config.torchair_graph_config.graph_batch_sizes_init
83+
assert ascend_config.torchair_graph_config.enable_multistream_mla
8284
assert ascend_config.torchair_graph_config.enable_multistream_moe
8385
assert ascend_config.ascend_scheduler_config.enabled
8486
assert ascend_config.ascend_scheduler_config.enable_chunked_prefill

vllm_ascend/ascend_config.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,8 @@ def __init__(self, torchair_graph_config):
5454
"graph_batch_sizes", [])
5555
self.graph_batch_sizes_init = torchair_graph_config.get(
5656
"graph_batch_sizes_init", False)
57+
self.enable_multistream_mla = torchair_graph_config.get(
58+
"enable_multistream_mla", False)
5759
self.enable_multistream_moe = torchair_graph_config.get(
5860
"enable_multistream_moe", False)
5961
self.enable_view_optimize = torchair_graph_config.get(

vllm_ascend/attention/mla_v1.py

Lines changed: 39 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from vllm_ascend.multistream.context import get_multistream_comm_context
2020
from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
2121
from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla
22+
from vllm_ascend.utils import npu_stream_switch, npu_wait_tensor
2223

2324
if TYPE_CHECKING:
2425
from vllm.v1.core.sched.output import SchedulerOutput
@@ -480,6 +481,9 @@ def __init__(
480481

481482
ascend_config = get_ascend_config()
482483
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
484+
self.enable_multistream_mla = \
485+
ascend_config.torchair_graph_config.enable_multistream_mla
486+
483487
# Adapt torch air graph mode with spec decoding.
484488
speculative_config = get_current_vllm_config().speculative_config
485489
if speculative_config is not None:
@@ -662,17 +666,20 @@ def exec_kv(
662666
kv = self.kv_a_proj_with_mqa(hidden_states)[0]
663667
# npu_kv_rmsnorm_rope_cache needs [B, N, S, D]
664668
kv = kv.view(B, N, S, self.kv_lora_rank + self.qk_rope_head_dim)
665-
k_pe, k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache(
666-
kv,
667-
self.kv_a_layernorm.weight,
668-
cos,
669-
sin,
670-
slots.to(torch.int64),
671-
kv_cache[1],
672-
kv_cache[0],
673-
epsilon=self.kv_a_layernorm.variance_epsilon,
674-
cache_mode="PA",
675-
)
669+
with npu_stream_switch("mla_secondary",
670+
0,
671+
enabled=self.enable_multistream_mla):
672+
k_pe, k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache(
673+
kv,
674+
self.kv_a_layernorm.weight,
675+
cos,
676+
sin,
677+
slots.to(torch.int64),
678+
kv_cache[1],
679+
kv_cache[0],
680+
epsilon=self.kv_a_layernorm.variance_epsilon,
681+
cache_mode="PA",
682+
)
676683
return k_pe, k_nope
677684

678685
def rope_single(
@@ -824,23 +831,38 @@ def forward(
824831
if has_decode:
825832
decode_k_nope = None
826833
assert attn_metadata.decode is not None
827-
decode_ql_nope, decode_q_pe = \
828-
self._q_proj_and_k_up_proj(decode_hs_or_q_c)
829834
if self.running_in_graph:
830835
seq_len = self.rotary_emb.max_position_embeddings
831836
cos = self.rotary_emb.cos_cached[:seq_len].to(
832-
dtype=decode_q_pe.dtype)
837+
dtype=decode_hs_or_q_c.dtype)
833838
sin = self.rotary_emb.sin_cached[:seq_len].to(
834-
dtype=decode_q_pe.dtype)
839+
dtype=decode_hs_or_q_c.dtype)
835840
cos = cos[attn_metadata.decode.input_positions]
836841
sin = sin[attn_metadata.decode.input_positions]
837842
cos = cos[:, None, None, :]
838843
sin = sin[:, None, None, :]
839-
840-
decode_q_pe = self.rope_single(decode_q_pe, cos, sin)
844+
# Without explicitly controlling the order, IndexByTensor operations
845+
# would be placed after `matmul W_KV_T` hindering the overlapping of
846+
# KvRmsNormRopeCache and SingleRope.
847+
npu_wait_tensor(decode_hs_or_q_c,
848+
cos,
849+
enabled=self.enable_multistream_mla)
850+
npu_wait_tensor(decode_hs_or_q_c,
851+
sin,
852+
enabled=self.enable_multistream_mla)
853+
decode_ql_nope, decode_q_pe = \
854+
self._q_proj_and_k_up_proj(decode_hs_or_q_c)
855+
if self.running_in_graph:
841856
decode_k_pe, decode_k_nope = self.exec_kv(
842857
hidden_states_or_kv_c_normed, cos, sin, kv_cache,
843858
attn_metadata.slot_mapping)
859+
with npu_stream_switch("mla_secondary",
860+
0,
861+
enabled=self.enable_multistream_mla):
862+
npu_wait_tensor(decode_q_pe,
863+
decode_k_pe,
864+
enabled=self.enable_multistream_mla)
865+
decode_q_pe = self.rope_single(decode_q_pe, cos, sin)
844866
else:
845867
decode_q_pe[...], decode_k_pe[...] = self.rotary_emb(
846868
attn_metadata.decode.input_positions,

vllm_ascend/models/deepseek_v2.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,8 @@
7171
from vllm_ascend.ops.fused_moe import AscendFusedMoE
7272
from vllm_ascend.quantization.quant_config import AscendLinearMethod
7373
from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod
74-
from vllm_ascend.utils import dispose_tensor
74+
from vllm_ascend.utils import (dispose_tensor, npu_stream_switch,
75+
npu_wait_tensor)
7576

7677
VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2
7778

@@ -496,6 +497,8 @@ def __init__(
496497

497498
ascend_config = get_ascend_config()
498499
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
500+
self.enable_multistream_mla = \
501+
ascend_config.torchair_graph_config.enable_multistream_mla
499502

500503
def forward(
501504
self,
@@ -505,7 +508,14 @@ def forward(
505508
attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
506509
if self.q_lora_rank is not None:
507510
ckq = self.q_a_proj(hidden_states)[0]
508-
hidden_states_or_q_c = self.q_a_layernorm(ckq)
511+
use_multistream_mla = (self.enable_multistream_mla
512+
and attn_metadata is not None
513+
and attn_metadata.num_decodes > 0)
514+
npu_wait_tensor(hidden_states, ckq, enabled=use_multistream_mla)
515+
with npu_stream_switch("mla_secondary",
516+
0,
517+
enabled=use_multistream_mla):
518+
hidden_states_or_q_c = self.q_a_layernorm(ckq)
509519
else:
510520
hidden_states_or_q_c = hidden_states
511521
if self.torchair_graph_enabled:

0 commit comments

Comments
 (0)