Commit 92155a3

Revert "Support multistream of MLA vector operations (vllm-project#1135)"
This reverts commit e72f94e.
1 parent: e72f94e

5 files changed (+19, -56 lines)

docs/source/user_guide/additional_config.md
Lines changed: 0 additions & 1 deletion

@@ -39,7 +39,6 @@ The details of each config option are as follows:
 | Name | Type | Default | Description |
 | ---- | ---- | ------- | ----------- |
 | `enabled` | bool | `False` | Whether to enable torchair graph mode |
-| `enable_multistream_mla`| bool | `False` | Whether to put vector ops of MLA to another stream |
 | `enable_multistream_moe`| bool | `False` | Whether to enable multistream shared expert |
 | `enable_view_optimize` | bool | `True` | Whether to enable torchair view optimization |
 | `use_cached_graph` | bool | `False` | Whether to use cached graph |
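These options are consumed through vLLM's `additional_config` passthrough. As a minimal sketch of how the remaining torchair graph options are set after this revert (the model name is a placeholder, and only the keys from the table above are meaningful):

```python
from vllm import LLM

# Minimal sketch: enabling torchair graph mode via vllm-ascend's
# additional_config. After this revert, "enable_multistream_mla" is no
# longer a recognized torchair_graph_config key.
llm = LLM(
    model="deepseek-ai/DeepSeek-V2-Lite",  # placeholder model
    additional_config={
        "torchair_graph_config": {
            "enabled": True,
            "enable_multistream_moe": True,
        },
    },
)
```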

tests/singlecard/test_ascend_config.py
Lines changed: 0 additions & 2 deletions

@@ -59,7 +59,6 @@ def test_run_with_ascend_config():
             "graph_batch_sizes": [1, 2, 4, 8],
             "graph_batch_sizes_init": False,
             "enable_multistream_moe": True,
-            "enable_multistream_mla": True,
         },
         "ascend_scheduler_config": {
             "enabled": True,
@@ -80,7 +79,6 @@ def test_run_with_ascend_config():
             1, 2, 4, 8
         ]
         assert not ascend_config.torchair_graph_config.graph_batch_sizes_init
-        assert ascend_config.torchair_graph_config.enable_multistream_mla
         assert ascend_config.torchair_graph_config.enable_multistream_moe
         assert ascend_config.ascend_scheduler_config.enabled
         assert ascend_config.ascend_scheduler_config.enable_chunked_prefill

vllm_ascend/ascend_config.py
Lines changed: 0 additions & 2 deletions

@@ -54,8 +54,6 @@ def __init__(self, torchair_graph_config):
             "graph_batch_sizes", [])
         self.graph_batch_sizes_init = torchair_graph_config.get(
             "graph_batch_sizes_init", False)
-        self.enable_multistream_mla = torchair_graph_config.get(
-            "enable_multistream_mla", False)
         self.enable_multistream_moe = torchair_graph_config.get(
             "enable_multistream_moe", False)
         self.enable_view_optimize = torchair_graph_config.get(
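Each option above is read with `dict.get` and a default, so in this reading pattern a stale `enable_multistream_mla` key in a user config is simply never read after the revert, while code that still references the attribute fails loudly. A tiny illustrative stand-in (the class name is made up for the example):

```python
# Illustrative stand-in for the dict-backed config pattern above.
class TorchairGraphConfigSketch:

    def __init__(self, torchair_graph_config: dict):
        # Absent keys fall back to their defaults; keys nothing reads
        # are ignored.
        self.enabled = torchair_graph_config.get("enabled", False)
        self.enable_multistream_moe = torchair_graph_config.get(
            "enable_multistream_moe", False)


cfg = TorchairGraphConfigSketch({
    "enable_multistream_moe": True,
    "enable_multistream_mla": True,  # stale key: no longer read
})
assert cfg.enable_multistream_moe
assert not hasattr(cfg, "enable_multistream_mla")
```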

vllm_ascend/attention/mla_v1.py
Lines changed: 17 additions & 39 deletions

@@ -19,7 +19,6 @@
 from vllm_ascend.multistream.context import get_multistream_comm_context
 from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
 from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla
-from vllm_ascend.utils import npu_stream_switch, npu_wait_tensor
 
 if TYPE_CHECKING:
     from vllm.v1.core.sched.output import SchedulerOutput
@@ -482,9 +481,6 @@ def __init__(
         ascend_config = get_ascend_config()
         self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
         self.enable_kv_nz = ascend_config.torchair_graph_config.enable_kv_nz
-        self.enable_multistream_mla = \
-            ascend_config.torchair_graph_config.enable_multistream_mla
-
         # Adapt torch air graph mode with spec decoding.
         speculative_config = get_current_vllm_config().speculative_config
         if speculative_config is not None:
@@ -668,20 +664,17 @@ def exec_kv(
         # npu_kv_rmsnorm_rope_cache needs [B, N, S, D]
         kv = kv.view(B, N, S, self.kv_lora_rank + self.qk_rope_head_dim)
         cache_mode = "PA_NZ" if self.enable_kv_nz else "PA"
-        with npu_stream_switch("mla_secondary",
-                               0,
-                               enabled=self.enable_multistream_mla):
-            k_pe, k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache(
-                kv,
-                self.kv_a_layernorm.weight,
-                cos,
-                sin,
-                slots.to(torch.int64),
-                kv_cache[1],
-                kv_cache[0],
-                epsilon=self.kv_a_layernorm.variance_epsilon,
-                cache_mode=cache_mode,
-            )
+        k_pe, k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache(
+            kv,
+            self.kv_a_layernorm.weight,
+            cos,
+            sin,
+            slots.to(torch.int64),
+            kv_cache[1],
+            kv_cache[0],
+            epsilon=self.kv_a_layernorm.variance_epsilon,
+            cache_mode=cache_mode,
+        )
         return k_pe, k_nope
 
     def exec_kv_prefill(
@@ -874,38 +867,23 @@ def forward(
         if has_decode:
             decode_k_nope = None
             assert attn_metadata.decode is not None
+            decode_ql_nope, decode_q_pe = \
+                self._q_proj_and_k_up_proj(decode_hs_or_q_c)
             if self.running_in_graph:
                 seq_len = self.rotary_emb.max_position_embeddings
                 cos = self.rotary_emb.cos_cached[:seq_len].to(
-                    dtype=decode_hs_or_q_c.dtype)
+                    dtype=decode_q_pe.dtype)
                 sin = self.rotary_emb.sin_cached[:seq_len].to(
-                    dtype=decode_hs_or_q_c.dtype)
+                    dtype=decode_q_pe.dtype)
                 cos = cos[attn_metadata.decode.input_positions]
                 sin = sin[attn_metadata.decode.input_positions]
                 cos = cos[:, None, None, :]
                 sin = sin[:, None, None, :]
-                # Without explicitly controlling the order, IndexByTensor operations
-                # would be placed after `matmul W_KV_T` hindering the overlapping of
-                # KvRmsNormRopeCache and SingleRope.
-                npu_wait_tensor(decode_hs_or_q_c,
-                                cos,
-                                enabled=self.enable_multistream_mla)
-                npu_wait_tensor(decode_hs_or_q_c,
-                                sin,
-                                enabled=self.enable_multistream_mla)
-            decode_ql_nope, decode_q_pe = \
-                self._q_proj_and_k_up_proj(decode_hs_or_q_c)
-            if self.running_in_graph:
+
+                decode_q_pe = self.rope_single(decode_q_pe, cos, sin)
                 decode_k_pe, decode_k_nope = self.exec_kv(
                     hidden_states_or_kv_c_normed, cos, sin, kv_cache,
                     attn_metadata.slot_mapping)
-                with npu_stream_switch("mla_secondary",
-                                       0,
-                                       enabled=self.enable_multistream_mla):
-                    npu_wait_tensor(decode_q_pe,
-                                    decode_k_pe,
-                                    enabled=self.enable_multistream_mla)
-                    decode_q_pe = self.rope_single(decode_q_pe, cos, sin)
             else:
                 decode_q_pe[...], decode_k_pe[...] = self.rotary_emb(
                     attn_metadata.decode.input_positions,
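For context on the helpers this revert drops from the imports above: `npu_stream_switch` and `npu_wait_tensor` in `vllm_ascend.utils` gate secondary-stream scoping behind an `enabled` flag, so with the flag off the multistream path degrades to ordinary sequential execution. A rough sketch of their shape, assuming they delegate to torchair's scope API (an approximation, not the verbatim utility code):

```python
from contextlib import nullcontext

import torch
# Assumed delegation target: torchair's scope-based stream controls.
from torchair.scope import npu_stream_switch as _npu_stream_switch
from torchair.scope import npu_wait_tensor as _npu_wait_tensor


def npu_stream_switch(tag: str, priority: int, *, enabled: bool = True):
    # Disabled: the wrapped ops stay on the default stream.
    return _npu_stream_switch(tag, priority) if enabled else nullcontext()


def npu_wait_tensor(self_tensor: torch.Tensor,
                    dependency: torch.Tensor,
                    *,
                    enabled: bool = True):
    # Adds an ordering edge: consumers of self_tensor wait on dependency.
    return _npu_wait_tensor(self_tensor, dependency) if enabled else self_tensor
```

Because both helpers are no-ops when disabled, the pre-revert code could keep a single code path and toggle the behavior purely through `enable_multistream_mla`; the revert removes the toggle and the scoping altogether.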

vllm_ascend/models/deepseek_v2.py
Lines changed: 2 additions & 12 deletions

@@ -71,8 +71,7 @@
 from vllm_ascend.ops.fused_moe import AscendFusedMoE
 from vllm_ascend.quantization.quant_config import AscendLinearMethod
 from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod
-from vllm_ascend.utils import (dispose_tensor, npu_stream_switch,
-                               npu_wait_tensor)
+from vllm_ascend.utils import dispose_tensor
 
 VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2
 
@@ -497,8 +496,6 @@ def __init__(
 
         ascend_config = get_ascend_config()
         self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
-        self.enable_multistream_mla = \
-            ascend_config.torchair_graph_config.enable_multistream_mla
 
     def forward(
         self,
@@ -508,14 +505,7 @@ def forward(
             attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
         if self.q_lora_rank is not None:
             ckq = self.q_a_proj(hidden_states)[0]
-            use_multistream_mla = (self.enable_multistream_mla
-                                   and attn_metadata is not None
-                                   and attn_metadata.num_decodes > 0)
-            npu_wait_tensor(hidden_states, ckq, enabled=use_multistream_mla)
-            with npu_stream_switch("mla_secondary",
-                                   0,
-                                   enabled=use_multistream_mla):
-                hidden_states_or_q_c = self.q_a_layernorm(ckq)
+            hidden_states_or_q_c = self.q_a_layernorm(ckq)
         else:
             hidden_states_or_q_c = hidden_states
         if self.torchair_graph_enabled:
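The scheduling idea behind the reverted feature, as a self-contained analogy: the removed Ascend code used torchair scopes rather than explicit stream objects, but the overlap pattern is the same. It is sketched here with `torch.cuda` streams purely to illustrate; the matmuls stand in for the real kernels:

```python
import torch


def overlapped_decode_step(q_in: torch.Tensor, kv_in: torch.Tensor):
    """Analogy for the reverted MLA multistream: kv-side work runs on a
    secondary stream while q-side work stays on the main stream."""
    main = torch.cuda.current_stream()
    secondary = torch.cuda.Stream()

    # The secondary stream must see kv_in fully produced before starting.
    secondary.wait_stream(main)
    with torch.cuda.stream(secondary):
        k = kv_in @ kv_in.T  # stands in for KvRmsNormRopeCache
    q = q_in @ q_in.T        # stands in for the q projection and rope

    # Re-join before anything downstream consumes k on the main stream.
    main.wait_stream(secondary)
    return q, k
```

After this revert, the corresponding MLA kernels simply run one after another on the default stream, and the gating on decode batches (`attn_metadata.num_decodes > 0`) disappears with them.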
