
Commit 6fec4ef

sdmyzlp authored and wangxiaoxin (A) committed
Support multistream of MLA vector operations (#1135)
### What this PR does / why we need it?

Move all vector operations to a secondary stream, with the expected overlapping being:

```
| q_rmsnorm |                | kv_norm_rope_cache |               | q_rope |
| matmul W_DQ | matmul W_DKV | index | index | matmul W_UQ | split | matmul W_KV_T |
```

Currently, the `IndexByTensor` operators introduced by the computation of `cos` and `sin` cannot be offloaded to the secondary stream due to a known bug in the graph fusion optimization pass. We therefore keep them on the main stream and only require that they be computed before `matmul W_UQ`, so that they do not hinder the later overlapping. The problem may be solved by a later optimization (#993), which hoists the computation of `cos` and `sin` up to the first layer.

### Does this PR introduce _any_ user-facing change?

Controlled by `torchair_graph_config.enable_multistream_mla`, which defaults to `False`.

### How was this patch tested?

Tested on a 1x16 910 node, with a tailored 2-layer DSKv2.

Signed-off-by: sdmyzlp <lrwei2@petalmail.com>
Signed-off-by: wangxiaoxin (A) <wangxiaoxin7@huawei.com>
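For readers skimming the diffs below, the recurring pattern is an `npu_stream_switch` context plus explicit `npu_wait_tensor` dependencies. The following is a condensed sketch (not verbatim repo code) of how the decode-path `q_rope` is offloaded; `offloaded_q_rope` and the `rope_single` parameter are illustrative names, and the sketch assumes `npu_stream_switch(name, priority, enabled=...)` yields a context manager that routes enclosed ops to the named secondary stream while `npu_wait_tensor(a, b, enabled=...)` makes `a` wait on `b`, both degrading to no-ops when `enabled=False`.

```python
# Condensed sketch of the offloading pattern in this PR (not verbatim repo
# code). Assumes npu_stream_switch(...) yields a context manager routing
# enclosed ops to the named secondary stream and npu_wait_tensor(a, b, ...)
# orders `a` after `b` across streams; both are no-ops when enabled=False.
import torch

from vllm_ascend.utils import npu_stream_switch, npu_wait_tensor


def offloaded_q_rope(decode_q_pe: torch.Tensor,
                     decode_k_pe: torch.Tensor,
                     cos: torch.Tensor,
                     sin: torch.Tensor,
                     rope_single,  # the attention module's rope helper
                     enable_multistream_mla: bool) -> torch.Tensor:
    with npu_stream_switch("mla_secondary", 0, enabled=enable_multistream_mla):
        # k_pe is produced by npu_kv_rmsnorm_rope_cache on the same secondary
        # stream; waiting on it keeps q_rope ordered after the KV path while
        # both overlap with the main-stream matmuls (W_UQ, W_KV_T).
        npu_wait_tensor(decode_q_pe, decode_k_pe,
                        enabled=enable_multistream_mla)
        return rope_single(decode_q_pe, cos, sin)
```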
1 parent bb27b25 commit 6fec4ef

File tree: 5 files changed (+56 lines, -19 lines)

docs/source/user_guide/additional_config.md

Lines changed: 1 addition & 0 deletions
```diff
@@ -40,6 +40,7 @@ The details of each config option are as follows:
 | Name | Type | Default | Description |
 | ---- | ---- | ------- | ----------- |
 | `enabled` | bool | `False` | Whether to enable torchair graph mode |
+| `enable_multistream_mla`| bool | `False` | Whether to put vector ops of MLA to another stream |
 | `enable_multistream_moe`| bool | `False` | Whether to enable multistream shared expert |
 | `enable_view_optimize` | bool | `True` | Whether to enable torchair view optimization |
 | `use_cached_graph` | bool | `False` | Whether to use cached graph |
```
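As a usage illustration (a hypothetical snippet, not part of this PR): the new flag sits alongside the existing torchair options and is passed through vLLM's `additional_config` mechanism that this document describes; the exact plumbing (offline `LLM` kwarg vs. a server-side flag) depends on your vLLM / vllm-ascend version, and the model name below is only a placeholder.

```python
# Hypothetical usage sketch: enabling the new flag next to the existing
# torchair options via vLLM's additional_config. Exact plumbing depends on
# your vLLM / vllm-ascend version; see the surrounding documentation.
from vllm import LLM

llm = LLM(
    model="deepseek-ai/DeepSeek-V2-Lite",  # placeholder model name
    additional_config={
        "torchair_graph_config": {
            "enabled": True,
            "enable_multistream_mla": True,  # new in this PR, defaults to False
        },
    },
)
```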

tests/singlecard/test_ascend_config.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -59,6 +59,7 @@ def test_run_with_ascend_config():
             "graph_batch_sizes": [1, 2, 4, 8],
             "graph_batch_sizes_init": False,
             "enable_multistream_moe": True,
+            "enable_multistream_mla": True,
         },
         "ascend_scheduler_config": {
             "enabled": True,
@@ -79,6 +80,7 @@ def test_run_with_ascend_config():
             1, 2, 4, 8
         ]
         assert not ascend_config.torchair_graph_config.graph_batch_sizes_init
+        assert ascend_config.torchair_graph_config.enable_multistream_mla
         assert ascend_config.torchair_graph_config.enable_multistream_moe
         assert ascend_config.ascend_scheduler_config.enabled
         assert ascend_config.ascend_scheduler_config.enable_chunked_prefill
```

vllm_ascend/ascend_config.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -56,6 +56,8 @@ def __init__(self, torchair_graph_config):
             "graph_batch_sizes", [])
         self.graph_batch_sizes_init = torchair_graph_config.get(
             "graph_batch_sizes_init", False)
+        self.enable_multistream_mla = torchair_graph_config.get(
+            "enable_multistream_mla", False)
         self.enable_multistream_moe = torchair_graph_config.get(
             "enable_multistream_moe", False)
         self.enable_view_optimize = torchair_graph_config.get(
```

vllm_ascend/attention/mla_v1.py

Lines changed: 39 additions & 17 deletions
```diff
@@ -20,6 +20,7 @@
 from vllm_ascend.multistream.context import get_multistream_comm_context
 from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
 from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla
+from vllm_ascend.utils import npu_stream_switch, npu_wait_tensor

 if TYPE_CHECKING:
     from vllm.v1.core.sched.output import SchedulerOutput
@@ -557,6 +558,9 @@ def __init__(
         ascend_config = get_ascend_config()
         self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
         self.enable_kv_nz = ascend_config.torchair_graph_config.enable_kv_nz
+        self.enable_multistream_mla = \
+            ascend_config.torchair_graph_config.enable_multistream_mla
+
         # Adapt torch air graph mode with spec decoding.
         speculative_config = get_current_vllm_config().speculative_config
         if speculative_config is not None:
@@ -861,17 +865,20 @@ def exec_kv(
         # npu_kv_rmsnorm_rope_cache needs [B, N, S, D]
         kv = kv.view(B, N, S, self.kv_lora_rank + self.qk_rope_head_dim)
         cache_mode = "PA_NZ" if self.enable_kv_nz else "PA"
-        k_pe, k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache(
-            kv,
-            self.kv_a_layernorm.weight,
-            cos,
-            sin,
-            slots.to(torch.int64),
-            kv_cache[1],
-            kv_cache[0],
-            epsilon=self.kv_a_layernorm.variance_epsilon,
-            cache_mode=cache_mode,
-        )
+        with npu_stream_switch("mla_secondary",
+                               0,
+                               enabled=self.enable_multistream_mla):
+            k_pe, k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache(
+                kv,
+                self.kv_a_layernorm.weight,
+                cos,
+                sin,
+                slots.to(torch.int64),
+                kv_cache[1],
+                kv_cache[0],
+                epsilon=self.kv_a_layernorm.variance_epsilon,
+                cache_mode=cache_mode,
+            )
         return k_pe, k_nope

     def exec_kv_prefill(
@@ -1064,23 +1071,38 @@ def forward(
         if has_decode:
             decode_k_nope = None
             assert attn_metadata.decode is not None
-            decode_ql_nope, decode_q_pe = \
-                self._q_proj_and_k_up_proj(decode_hs_or_q_c)
             if self.running_in_graph:
                 seq_len = self.rotary_emb.max_position_embeddings
                 cos = self.rotary_emb.cos_cached[:seq_len].to(
-                    dtype=decode_q_pe.dtype)
+                    dtype=decode_hs_or_q_c.dtype)
                 sin = self.rotary_emb.sin_cached[:seq_len].to(
-                    dtype=decode_q_pe.dtype)
+                    dtype=decode_hs_or_q_c.dtype)
                 cos = cos[attn_metadata.decode.input_positions]
                 sin = sin[attn_metadata.decode.input_positions]
                 cos = cos[:, None, None, :]
                 sin = sin[:, None, None, :]
-
-                decode_q_pe = self.rope_single(decode_q_pe, cos, sin)
+                # Without explicitly controlling the order, IndexByTensor operations
+                # would be placed after `matmul W_KV_T` hindering the overlapping of
+                # KvRmsNormRopeCache and SingleRope.
+                npu_wait_tensor(decode_hs_or_q_c,
+                                cos,
+                                enabled=self.enable_multistream_mla)
+                npu_wait_tensor(decode_hs_or_q_c,
+                                sin,
+                                enabled=self.enable_multistream_mla)
+            decode_ql_nope, decode_q_pe = \
+                self._q_proj_and_k_up_proj(decode_hs_or_q_c)
+            if self.running_in_graph:
                 decode_k_pe, decode_k_nope = self.exec_kv(
                     hidden_states_or_kv_c_normed, cos, sin, kv_cache,
                     attn_metadata.slot_mapping)
+                with npu_stream_switch("mla_secondary",
+                                       0,
+                                       enabled=self.enable_multistream_mla):
+                    npu_wait_tensor(decode_q_pe,
+                                    decode_k_pe,
+                                    enabled=self.enable_multistream_mla)
+                    decode_q_pe = self.rope_single(decode_q_pe, cos, sin)
             else:
                 decode_q_pe[...], decode_k_pe[...] = self.rotary_emb(
                     attn_metadata.decode.input_positions,
```

vllm_ascend/models/deepseek_v2.py

Lines changed: 12 additions & 2 deletions
```diff
@@ -71,7 +71,8 @@
 from vllm_ascend.ops.fused_moe import AscendFusedMoE
 from vllm_ascend.quantization.quant_config import AscendLinearMethod
 from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod
-from vllm_ascend.utils import dispose_tensor
+from vllm_ascend.utils import (dispose_tensor, npu_stream_switch,
+                               npu_wait_tensor)

 VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2

@@ -496,6 +497,8 @@ def __init__(

         ascend_config = get_ascend_config()
         self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
+        self.enable_multistream_mla = \
+            ascend_config.torchair_graph_config.enable_multistream_mla

     def forward(
         self,
@@ -505,7 +508,14 @@ def forward(
                 attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
         if self.q_lora_rank is not None:
             ckq = self.q_a_proj(hidden_states)[0]
-            hidden_states_or_q_c = self.q_a_layernorm(ckq)
+            use_multistream_mla = (self.enable_multistream_mla
+                                   and attn_metadata is not None
+                                   and attn_metadata.num_decodes > 0)
+            npu_wait_tensor(hidden_states, ckq, enabled=use_multistream_mla)
+            with npu_stream_switch("mla_secondary",
+                                   0,
+                                   enabled=use_multistream_mla):
+                hidden_states_or_q_c = self.q_a_layernorm(ckq)
         else:
             hidden_states_or_q_c = hidden_states
         if self.torchair_graph_enabled:
```
