@@ -18,6 +18,7 @@
 from vllm_ascend.multistream.context import get_multistream_comm_context
 from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
 from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla
+from vllm_ascend.utils import npu_stream_switch, npu_wait_tensor
 
 if TYPE_CHECKING:
     from vllm.v1.core.sched.output import SchedulerOutput
@@ -475,6 +476,9 @@ def __init__(
 
         ascend_config = get_ascend_config()
         self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
+        self.enable_multistream_mla = \
+            ascend_config.torchair_graph_config.enable_multistream_mla
+
         # Adapt torch air graph mode with spec decoding.
         speculative_config = get_current_vllm_config().speculative_config
        if speculative_config is not None:
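The new flag is read from torchair_graph_config right next to the existing enabled switch. A hedged usage sketch of how it might be turned on from user code, assuming the option is surfaced through vLLM's additional_config like other torchair_graph_config knobs (the key layout and the model name below are assumptions, not part of this diff):

# Hypothetical sketch: enabling multistream MLA together with torchair graph
# mode. The additional_config layout is assumed, not taken from this change.
from vllm import LLM

llm = LLM(
    model="deepseek-ai/DeepSeek-V2-Lite",  # placeholder MLA-based model
    additional_config={
        "torchair_graph_config": {
            "enabled": True,                 # read as self.torchair_graph_enabled
            "enable_multistream_mla": True,  # read as self.enable_multistream_mla
        },
    },
)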
@@ -648,17 +652,20 @@ def exec_kv(
         kv = self.kv_a_proj_with_mqa(hidden_states)[0]
         # npu_kv_rmsnorm_rope_cache needs [B, N, S, D]
         kv = kv.view(B, N, S, self.kv_lora_rank + self.qk_rope_head_dim)
-        k_pe, k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache(
-            kv,
-            self.kv_a_layernorm.weight,
-            cos,
-            sin,
-            slots.to(torch.int64),
-            kv_cache[1],
-            kv_cache[0],
-            epsilon=self.kv_a_layernorm.variance_epsilon,
-            cache_mode="PA",
-        )
+        with npu_stream_switch("mla_secondary",
+                               0,
+                               enabled=self.enable_multistream_mla):
+            k_pe, k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache(
+                kv,
+                self.kv_a_layernorm.weight,
+                cos,
+                sin,
+                slots.to(torch.int64),
+                kv_cache[1],
+                kv_cache[0],
+                epsilon=self.kv_a_layernorm.variance_epsilon,
+                cache_mode="PA",
+            )
         return k_pe, k_nope
 
     def rope_single(
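npu_stream_switch comes from vllm_ascend.utils and is not shown in this diff. As a rough mental model only, it behaves like a keyed secondary-stream context manager that degrades to a no-op when disabled, so the wrapped call site keeps its original single-stream behaviour when enable_multistream_mla is off. A minimal sketch of that pattern, written against the CUDA stream API as a stand-in for the NPU one (names, priority handling, and all details here are assumptions, not the real helper):

# Illustrative sketch only -- not the vllm_ascend.utils implementation.
import contextlib
import torch

_secondary_streams: dict = {}

@contextlib.contextmanager
def stream_switch_sketch(tag: str, priority: int = 0, *, enabled: bool = True):
    if not enabled:
        # Disabled: run the body on whatever stream is already current.
        yield
        return
    if tag not in _secondary_streams:
        _secondary_streams[tag] = torch.cuda.Stream(priority=priority)
    # Kernels launched inside the block are queued on the named side stream and
    # can overlap with work still pending on the default stream.
    with torch.cuda.stream(_secondary_streams[tag]):
        yield

Under that reading, exec_kv above queues npu_kv_rmsnorm_rope_cache on the "mla_secondary" stream so it can overlap with whatever is still pending on the default stream.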
@@ -813,20 +820,28 @@ def forward(
             decode_ql_nope, decode_q_pe = \
                 self._q_proj_and_k_up_proj(decode_hs_or_q_c)
             if self.running_in_graph:
-                seq_len = self.rotary_emb.max_position_embeddings
-                cos = self.rotary_emb.cos_cached[:seq_len].to(
-                    dtype=decode_q_pe.dtype)
-                sin = self.rotary_emb.sin_cached[:seq_len].to(
-                    dtype=decode_q_pe.dtype)
-                cos = cos[attn_metadata.decode.input_positions]
-                sin = sin[attn_metadata.decode.input_positions]
-                cos = cos[:, None, None, :]
-                sin = sin[:, None, None, :]
-
-                decode_q_pe = self.rope_single(decode_q_pe, cos, sin)
+                with npu_stream_switch("mla_secondary",
+                                       0,
+                                       enabled=self.enable_multistream_mla):
+                    seq_len = self.rotary_emb.max_position_embeddings
+                    cos = self.rotary_emb.cos_cached[:seq_len].to(
+                        dtype=decode_q_pe.dtype)
+                    sin = self.rotary_emb.sin_cached[:seq_len].to(
+                        dtype=decode_q_pe.dtype)
+                    cos = cos[attn_metadata.decode.input_positions]
+                    sin = sin[attn_metadata.decode.input_positions]
+                    cos = cos[:, None, None, :]
+                    sin = sin[:, None, None, :]
                 decode_k_pe, decode_k_nope = self.exec_kv(
                     hidden_states_or_kv_c_normed, cos, sin, kv_cache,
                     attn_metadata.slot_mapping)
+                with npu_stream_switch("mla_secondary",
+                                       0,
+                                       enabled=self.enable_multistream_mla):
+                    npu_wait_tensor(decode_q_pe,
+                                    decode_k_pe,
+                                    enabled=self.enable_multistream_mla)
+                    decode_q_pe = self.rope_single(decode_q_pe, cos, sin)
             else:
                 decode_q_pe[...], decode_k_pe[...] = self.rotary_emb(
                     attn_metadata.decode.input_positions,
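npu_wait_tensor is also imported from vllm_ascend.utils and not shown here; in the hunk above it is placed so that the RoPE on decode_q_pe does not start before decode_k_pe is available, keeping the overlapped work ordered. A hedged stand-in for that kind of ordering edge, again using CUDA events rather than the real tensor-level dependency (the producer_stream argument exists only for this sketch; the call site above passes just the two tensors plus enabled):

# Illustrative sketch only -- not vllm_ascend.utils.npu_wait_tensor.
# Approximates "do not consume `tensor` until `dependency` is ready" by making
# the current stream wait on an event recorded after the producer's queued work.
import torch

def wait_tensor_sketch(tensor: torch.Tensor,
                       dependency: torch.Tensor,
                       producer_stream: torch.cuda.Stream,
                       *,
                       enabled: bool = True) -> torch.Tensor:
    if enabled:
        ready = torch.cuda.Event()
        ready.record(producer_stream)                  # after dependency's kernels
        torch.cuda.current_stream().wait_event(ready)  # stream-side wait; host not blocked
    return tensor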