Skip to content

Commit dcea3aa

Browse files
committed
Offload vector operations of MLA to another stream
With the expected overlaping being: ``` | cos/sin | | q_rmsnorm | | kv_norm_rope_cache | | q_rope | | matmul W_DQ | matmul W_DKV | matmul W_UQ | split | matmul W_KV_T | ``` Controlled by `torchair_graph_config.enable_multistream_mla`, defaulted to False. Signed-off-by: sdmyzlp <lrwei2@petalmail.com>
1 parent daf341c commit dcea3aa

File tree

5 files changed

+54
-24
lines changed

5 files changed

+54
-24
lines changed

docs/source/user_guide/additional_config.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ The details of each config option are as follows:
3939
| Name | Type | Default | Description |
4040
| ---- | ---- | ------- | ----------- |
4141
| `enabled` | bool | `False` | Whether to enable torchair graph mode |
42+
| `enable_multistream_mla`| bool | `False` | Whether to put vector ops of MLA to another stream |
4243
| `enable_multistream_moe`| bool | `False` | Whether to enable multistream shared expert |
4344
| `enable_view_optimize` | bool | `True` | Whether to enable torchair view optimization |
4445
| `use_cached_graph` | bool | `False` | Whether to use cached graph |

tests/singlecard/test_ascend_config.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ def test_run_with_ascend_config():
5959
"graph_batch_sizes": [1, 2, 4, 8],
6060
"graph_batch_sizes_init": False,
6161
"enable_multistream_moe": True,
62+
"enable_multistream_mla": True,
6263
},
6364
"ascend_scheduler_config": {
6465
"enabled": True,
@@ -79,6 +80,7 @@ def test_run_with_ascend_config():
7980
1, 2, 4, 8
8081
]
8182
assert not ascend_config.torchair_graph_config.graph_batch_sizes_init
83+
assert ascend_config.torchair_graph_config.enable_multistream_mla
8284
assert ascend_config.torchair_graph_config.enable_multistream_moe
8385
assert ascend_config.ascend_scheduler_config.enabled
8486
assert ascend_config.ascend_scheduler_config.enable_chunked_prefill

vllm_ascend/ascend_config.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,8 @@ def __init__(self, torchair_graph_config):
5454
"graph_batch_sizes", [])
5555
self.graph_batch_sizes_init = torchair_graph_config.get(
5656
"graph_batch_sizes_init", False)
57+
self.enable_multistream_mla = torchair_graph_config.get(
58+
"enable_multistream_mla", False)
5759
self.enable_multistream_moe = torchair_graph_config.get(
5860
"enable_multistream_moe", False)
5961
self.enable_view_optimize = torchair_graph_config.get(

vllm_ascend/attention/mla_v1.py

Lines changed: 37 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from vllm_ascend.multistream.context import get_multistream_comm_context
1919
from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
2020
from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla
21+
from vllm_ascend.utils import npu_stream_switch, npu_wait_tensor
2122

2223
if TYPE_CHECKING:
2324
from vllm.v1.core.sched.output import SchedulerOutput
@@ -475,6 +476,9 @@ def __init__(
475476

476477
ascend_config = get_ascend_config()
477478
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
479+
self.enable_multistream_mla = \
480+
ascend_config.torchair_graph_config.enable_multistream_mla
481+
478482
# Adapt torch air graph mode with spec decoding.
479483
speculative_config = get_current_vllm_config().speculative_config
480484
if speculative_config is not None:
@@ -648,17 +652,20 @@ def exec_kv(
648652
kv = self.kv_a_proj_with_mqa(hidden_states)[0]
649653
# npu_kv_rmsnorm_rope_cache needs [B, N, S, D]
650654
kv = kv.view(B, N, S, self.kv_lora_rank + self.qk_rope_head_dim)
651-
k_pe, k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache(
652-
kv,
653-
self.kv_a_layernorm.weight,
654-
cos,
655-
sin,
656-
slots.to(torch.int64),
657-
kv_cache[1],
658-
kv_cache[0],
659-
epsilon=self.kv_a_layernorm.variance_epsilon,
660-
cache_mode="PA",
661-
)
655+
with npu_stream_switch("mla_secondary",
656+
0,
657+
enabled=self.enable_multistream_mla):
658+
k_pe, k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache(
659+
kv,
660+
self.kv_a_layernorm.weight,
661+
cos,
662+
sin,
663+
slots.to(torch.int64),
664+
kv_cache[1],
665+
kv_cache[0],
666+
epsilon=self.kv_a_layernorm.variance_epsilon,
667+
cache_mode="PA",
668+
)
662669
return k_pe, k_nope
663670

664671
def rope_single(
@@ -813,20 +820,28 @@ def forward(
813820
decode_ql_nope, decode_q_pe = \
814821
self._q_proj_and_k_up_proj(decode_hs_or_q_c)
815822
if self.running_in_graph:
816-
seq_len = self.rotary_emb.max_position_embeddings
817-
cos = self.rotary_emb.cos_cached[:seq_len].to(
818-
dtype=decode_q_pe.dtype)
819-
sin = self.rotary_emb.sin_cached[:seq_len].to(
820-
dtype=decode_q_pe.dtype)
821-
cos = cos[attn_metadata.decode.input_positions]
822-
sin = sin[attn_metadata.decode.input_positions]
823-
cos = cos[:, None, None, :]
824-
sin = sin[:, None, None, :]
825-
826-
decode_q_pe = self.rope_single(decode_q_pe, cos, sin)
823+
with npu_stream_switch("mla_secondary",
824+
0,
825+
enabled=self.enable_multistream_mla):
826+
seq_len = self.rotary_emb.max_position_embeddings
827+
cos = self.rotary_emb.cos_cached[:seq_len].to(
828+
dtype=decode_q_pe.dtype)
829+
sin = self.rotary_emb.sin_cached[:seq_len].to(
830+
dtype=decode_q_pe.dtype)
831+
cos = cos[attn_metadata.decode.input_positions]
832+
sin = sin[attn_metadata.decode.input_positions]
833+
cos = cos[:, None, None, :]
834+
sin = sin[:, None, None, :]
827835
decode_k_pe, decode_k_nope = self.exec_kv(
828836
hidden_states_or_kv_c_normed, cos, sin, kv_cache,
829837
attn_metadata.slot_mapping)
838+
with npu_stream_switch("mla_secondary",
839+
0,
840+
enabled=self.enable_multistream_mla):
841+
npu_wait_tensor(decode_q_pe,
842+
decode_k_pe,
843+
enabled=self.enable_multistream_mla)
844+
decode_q_pe = self.rope_single(decode_q_pe, cos, sin)
830845
else:
831846
decode_q_pe[...], decode_k_pe[...] = self.rotary_emb(
832847
attn_metadata.decode.input_positions,

vllm_ascend/models/deepseek_v2.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,8 @@
7070
from vllm_ascend.distributed.parallel_state import get_ep_group
7171
from vllm_ascend.ops.fused_moe import AscendFusedMoE
7272
from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod
73-
from vllm_ascend.utils import dispose_tensor
73+
from vllm_ascend.utils import (dispose_tensor, npu_stream_switch,
74+
npu_wait_tensor)
7475

7576
VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2
7677

@@ -488,6 +489,8 @@ def __init__(
488489

489490
ascend_config = get_ascend_config()
490491
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
492+
self.enable_multistream_mla = \
493+
ascend_config.torchair_graph_config.enable_multistream_mla
491494

492495
def forward(
493496
self,
@@ -497,7 +500,14 @@ def forward(
497500
attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
498501
if self.q_lora_rank is not None:
499502
ckq = self.q_a_proj(hidden_states)[0]
500-
hidden_states_or_q_c = self.q_a_layernorm(ckq)
503+
use_multistream_mla = (self.enable_multistream_mla
504+
and attn_metadata is not None
505+
and attn_metadata.num_decodes > 0)
506+
npu_wait_tensor(hidden_states, ckq, enabled=use_multistream_mla)
507+
with npu_stream_switch("mla_secondary",
508+
0,
509+
enabled=use_multistream_mla):
510+
hidden_states_or_q_c = self.q_a_layernorm(ckq)
501511
else:
502512
hidden_states_or_q_c = hidden_states
503513
if self.torchair_graph_enabled:

0 commit comments

Comments
 (0)