19 | 19 | from vllm_ascend.multistream.context import get_multistream_comm_context
20 | 20 | from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
21 | 21 | from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla
   | 22 | +from vllm_ascend.utils import npu_stream_switch, npu_wait_tensor
22 | 23 |
23 | 24 | if TYPE_CHECKING:
24 | 25 |     from vllm.v1.core.sched.output import SchedulerOutput
@@ -480,6 +481,9 @@ def __init__(
480 | 481 |
481 | 482 |         ascend_config = get_ascend_config()
482 | 483 |         self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
    | 484 | +        self.enable_multistream_mla = \
    | 485 | +            ascend_config.torchair_graph_config.enable_multistream_mla
    | 486 | +
483 | 487 |         # Adapt torch air graph mode with spec decoding.
484 | 488 |         speculative_config = get_current_vllm_config().speculative_config
485 | 489 |         if speculative_config is not None:
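
For reference, a hedged usage sketch of how the new flag would typically be switched on. It assumes vllm-ascend reads `torchair_graph_config` from vLLM's `additional_config` (as the other torchair knobs are); the model name is only a placeholder, not part of this commit.

```python
# Hedged usage sketch (not part of this commit): enabling the flag added above,
# assuming torchair_graph_config is picked up from vLLM's additional_config.
from vllm import LLM

llm = LLM(
    model="deepseek-ai/DeepSeek-V2-Lite",    # placeholder model name
    additional_config={
        "torchair_graph_config": {
            "enabled": True,                  # torchair graph mode on
            "enable_multistream_mla": True,   # the option introduced here
        },
    },
)
```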
@@ -662,17 +666,20 @@ def exec_kv(
662 | 666 |         kv = self.kv_a_proj_with_mqa(hidden_states)[0]
663 | 667 |         # npu_kv_rmsnorm_rope_cache needs [B, N, S, D]
664 | 668 |         kv = kv.view(B, N, S, self.kv_lora_rank + self.qk_rope_head_dim)
665 |     | -        k_pe, k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache(
666 |     | -            kv,
667 |     | -            self.kv_a_layernorm.weight,
668 |     | -            cos,
669 |     | -            sin,
670 |     | -            slots.to(torch.int64),
671 |     | -            kv_cache[1],
672 |     | -            kv_cache[0],
673 |     | -            epsilon=self.kv_a_layernorm.variance_epsilon,
674 |     | -            cache_mode="PA",
675 |     | -        )
    | 669 | +        with npu_stream_switch("mla_secondary",
    | 670 | +                               0,
    | 671 | +                               enabled=self.enable_multistream_mla):
    | 672 | +            k_pe, k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache(
    | 673 | +                kv,
    | 674 | +                self.kv_a_layernorm.weight,
    | 675 | +                cos,
    | 676 | +                sin,
    | 677 | +                slots.to(torch.int64),
    | 678 | +                kv_cache[1],
    | 679 | +                kv_cache[0],
    | 680 | +                epsilon=self.kv_a_layernorm.variance_epsilon,
    | 681 | +                cache_mode="PA",
    | 682 | +            )
676 | 683 |         return k_pe, k_nope
677 | 684 |
678 | 685 |     def rope_single(
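
The hunk above moves the KV rms-norm + rope cache write onto a secondary stream so it can overlap with the query-side work issued on the default stream. Below is a minimal, self-contained sketch of that overlap pattern, written with CUDA stream primitives as an analogy for the NPU stream helpers; the function and weight names are illustrative only, not vllm-ascend code.

```python
# Sketch of the overlap exec_kv is setting up, using torch.cuda streams as a
# stand-in for npu_stream_switch("mla_secondary", ...).
import torch

def overlapped_projections(hidden, w_q, w_kv):
    main = torch.cuda.current_stream()
    side = torch.cuda.Stream()          # plays the role of "mla_secondary"
    side.wait_stream(main)              # side stream must see `hidden` ready
    with torch.cuda.stream(side):
        kv = hidden @ w_kv              # stands in for npu_kv_rmsnorm_rope_cache
    q = hidden @ w_q                    # default-stream work overlaps with `kv`
    main.wait_stream(side)              # re-join before both results are used
    return q, kv

if torch.cuda.is_available():
    h = torch.randn(8, 64, device="cuda")
    q, kv = overlapped_projections(h,
                                   torch.randn(64, 64, device="cuda"),
                                   torch.randn(64, 32, device="cuda"))
```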
@@ -824,23 +831,38 @@ def forward(
824 | 831 |         if has_decode:
825 | 832 |             decode_k_nope = None
826 | 833 |             assert attn_metadata.decode is not None
827 |     | -            decode_ql_nope, decode_q_pe = \
828 |     | -                self._q_proj_and_k_up_proj(decode_hs_or_q_c)
829 | 834 |             if self.running_in_graph:
830 | 835 |                 seq_len = self.rotary_emb.max_position_embeddings
831 | 836 |                 cos = self.rotary_emb.cos_cached[:seq_len].to(
832 |     | -                    dtype=decode_q_pe.dtype)
    | 837 | +                    dtype=decode_hs_or_q_c.dtype)
833 | 838 |                 sin = self.rotary_emb.sin_cached[:seq_len].to(
834 |     | -                    dtype=decode_q_pe.dtype)
    | 839 | +                    dtype=decode_hs_or_q_c.dtype)
835 | 840 |                 cos = cos[attn_metadata.decode.input_positions]
836 | 841 |                 sin = sin[attn_metadata.decode.input_positions]
837 | 842 |                 cos = cos[:, None, None, :]
838 | 843 |                 sin = sin[:, None, None, :]
839 |     | -
840 |     | -                decode_q_pe = self.rope_single(decode_q_pe, cos, sin)
    | 844 | +                # Without explicitly controlling the order, IndexByTensor operations
    | 845 | +                # would be placed after `matmul W_KV_T` hindering the overlapping of
    | 846 | +                # KvRmsNormRopeCache and SingleRope.
    | 847 | +                npu_wait_tensor(decode_hs_or_q_c,
    | 848 | +                                cos,
    | 849 | +                                enabled=self.enable_multistream_mla)
    | 850 | +                npu_wait_tensor(decode_hs_or_q_c,
    | 851 | +                                sin,
    | 852 | +                                enabled=self.enable_multistream_mla)
    | 853 | +            decode_ql_nope, decode_q_pe = \
    | 854 | +                self._q_proj_and_k_up_proj(decode_hs_or_q_c)
    | 855 | +            if self.running_in_graph:
841 | 856 |                 decode_k_pe, decode_k_nope = self.exec_kv(
842 | 857 |                     hidden_states_or_kv_c_normed, cos, sin, kv_cache,
843 | 858 |                     attn_metadata.slot_mapping)
    | 859 | +                with npu_stream_switch("mla_secondary",
    | 860 | +                                       0,
    | 861 | +                                       enabled=self.enable_multistream_mla):
    | 862 | +                    npu_wait_tensor(decode_q_pe,
    | 863 | +                                    decode_k_pe,
    | 864 | +                                    enabled=self.enable_multistream_mla)
    | 865 | +                    decode_q_pe = self.rope_single(decode_q_pe, cos, sin)
844 | 866 |             else:
845 | 867 |                 decode_q_pe[...], decode_k_pe[...] = self.rotary_emb(
846 | 868 |                     attn_metadata.decode.input_positions,
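
The inline comment in the hunk explains the intent: in graph mode the `cos[...]` / `sin[...]` gathers (IndexByTensor) could otherwise be scheduled after the W_KV^T matmul, which would serialize KvRmsNormRopeCache and SingleRope instead of letting them overlap, so the code pins the ordering explicitly. As a reading aid, here are hedged, CPU-friendly stand-ins that only document the semantics the two helpers are used with above; the real `vllm_ascend.utils` implementations are NPU/torchair specific and are not reproduced here.

```python
# Hedged stand-ins documenting how the helpers are used in this diff; these are
# assumptions about intended semantics, not vllm-ascend's implementations.
from contextlib import contextmanager
import torch

@contextmanager
def npu_stream_switch(tag: str, priority: int = 0, *, enabled: bool = True):
    """Run the enclosed ops on the secondary stream `tag` when enabled;
    otherwise leave them on the default stream (no-op here)."""
    yield

def npu_wait_tensor(waiting: torch.Tensor, dependency: torch.Tensor,
                    *, enabled: bool = True) -> torch.Tensor:
    """Ordering hint: ops consuming `waiting` must not start before
    `dependency` has been produced. A no-op when disabled."""
    return waiting
```

The `enabled` flag lets the same code path compile to an ordinary single-stream graph when multistream MLA is turned off, so the feature can be gated purely by the `torchair_graph_config` option added in this commit.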