```diff
 from vllm_ascend.multistream.context import get_multistream_comm_context
 from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
 from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla
+from vllm_ascend.utils import npu_stream_switch, npu_wait_tensor
 
 if TYPE_CHECKING:
     from vllm.v1.core.sched.output import SchedulerOutput
```
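The new import pulls in two small helpers from `vllm_ascend.utils`. Judging from the call sites below, `npu_stream_switch(tag, priority, enabled=...)` acts as a context manager that issues the enclosed ops on a named secondary NPU stream when enabled (and is a no-op otherwise), while `npu_wait_tensor(waiter, dependency, enabled)` makes the work consuming `waiter` wait until `dependency` has been produced. The sketch below only illustrates that pattern; it is not the vllm_ascend implementation, it uses `torch.cuda` streams as a stand-in for the Ascend stream API, and the `_sketch`-suffixed names are hypothetical.

```python
# Hypothetical sketch of the stream helpers' behavior, using torch.cuda as a
# stand-in for the Ascend NPU stream API (not the vllm_ascend implementation).
from contextlib import nullcontext
from typing import Dict

import torch

_streams_sketch: Dict[str, torch.cuda.Stream] = {}


def stream_switch_sketch(tag: str, priority: int, enabled: bool = True):
    """Run the enclosed ops on a named secondary stream when enabled, else no-op."""
    if not enabled:
        return nullcontext()
    if tag not in _streams_sketch:
        _streams_sketch[tag] = torch.cuda.Stream(priority=priority)
    return torch.cuda.stream(_streams_sketch[tag])


def wait_tensor_sketch(waiter: torch.Tensor, dependency: torch.Tensor,
                       enabled: bool = True) -> torch.Tensor:
    """Block the current stream until the stream that produced `dependency`
    is done, so `waiter` can be consumed safely on this stream afterwards."""
    if enabled:
        torch.cuda.current_stream().wait_stream(
            torch.cuda.default_stream(dependency.device))
    return waiter
```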
```diff
@@ -461,6 +462,8 @@ def __init__(
 
         ascend_config = get_ascend_config()
         self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
+        self.enable_multistream_mla = \
+            ascend_config.torchair_graph_config.enable_multistream_mla
 
     def _v_up_proj_and_o_proj(self, x):
         # Convert from (B, N, L) to (N, B, L)
```
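The new `enable_multistream_mla` flag is read once from the torchair graph config at construction time, alongside the existing `torchair_graph_enabled` switch. Assuming vllm-ascend's usual `additional_config` plumbing for torchair options (the exact key layout should be checked against the installed version), turning the feature on from user code would look roughly like this:

```python
# Hedged example: enabling the new multistream MLA path. Key names follow the
# torchair_graph_config attributes referenced in this diff; verify them against
# your vllm-ascend release before relying on this.
from vllm import LLM

llm = LLM(
    model="deepseek-ai/DeepSeek-V2-Lite",  # any MLA model; name is illustrative
    additional_config={
        "torchair_graph_config": {
            "enabled": True,                 # maps to self.torchair_graph_enabled
            "enable_multistream_mla": True,  # maps to the new flag added above
        },
    },
)
```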
|
```diff
@@ -626,17 +629,19 @@ def exec_kv(
         kv = self.kv_a_proj_with_mqa(hidden_states)[0]
         # npu_kv_rmsnorm_rope_cache needs [B, N, S, D]
         kv = kv.view(B, N, S, self.kv_lora_rank + self.qk_rope_head_dim)
-        k_pe, k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache(
-            kv,
-            self.kv_a_layernorm.weight,
-            cos,
-            sin,
-            slots.to(torch.int64),
-            kv_cache[1],
-            kv_cache[0],
-            epsilon=self.kv_a_layernorm.variance_epsilon,
-            cache_mode="PA",
-        )
+        with npu_stream_switch("mla_secondary", 0,
+                               enabled=self.enable_multistream_mla):
+            k_pe, k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache(
+                kv,
+                self.kv_a_layernorm.weight,
+                cos,
+                sin,
+                slots.to(torch.int64),
+                kv_cache[1],
+                kv_cache[0],
+                epsilon=self.kv_a_layernorm.variance_epsilon,
+                cache_mode="PA",
+            )
         return k_pe, k_nope
 
     def rope_single(
```
```diff
@@ -769,20 +774,25 @@ def forward(
             decode_ql_nope, decode_q_pe = \
                 self._q_proj_and_k_up_proj(decode_hs_or_q_c)
             if self.running_in_graph:
-                seq_len = self.rotary_emb.max_position_embeddings
-                cos = self.rotary_emb.cos_cached[:seq_len].to(
-                    dtype=decode_q_pe.dtype)
-                sin = self.rotary_emb.sin_cached[:seq_len].to(
-                    dtype=decode_q_pe.dtype)
-                cos = cos[attn_metadata.decode.input_positions]
-                sin = sin[attn_metadata.decode.input_positions]
-                cos = cos[:, None, None, :]
-                sin = sin[:, None, None, :]
-
-                decode_q_pe = self.rope_single(decode_q_pe, cos, sin)
+                with npu_stream_switch("mla_secondary", 0,
+                                       enabled=self.enable_multistream_mla):
+                    seq_len = self.rotary_emb.max_position_embeddings
+                    cos = self.rotary_emb.cos_cached[:seq_len].to(
+                        dtype=decode_q_pe.dtype)
+                    sin = self.rotary_emb.sin_cached[:seq_len].to(
+                        dtype=decode_q_pe.dtype)
+                    cos = cos[attn_metadata.decode.input_positions]
+                    sin = sin[attn_metadata.decode.input_positions]
+                    cos = cos[:, None, None, :]
+                    sin = sin[:, None, None, :]
                 decode_k_pe, decode_k_nope = self.exec_kv(
                     hidden_states_or_kv_c_normed, cos, sin, kv_cache,
                     attn_metadata.slot_mapping)
+                with npu_stream_switch("mla_secondary", 0,
+                                       enabled=self.enable_multistream_mla):
+                    npu_wait_tensor(decode_q_pe, decode_k_pe,
+                                    self.enable_multistream_mla)
+                    decode_q_pe = self.rope_single(decode_q_pe, cos, sin)
             else:
                 decode_q_pe[...], decode_k_pe[...] = self.rotary_emb(
                     attn_metadata.decode.input_positions,
```
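In the decode path, the cos/sin gather and the final `rope_single` on `decode_q_pe` now also run on the secondary stream, and `npu_wait_tensor(decode_q_pe, decode_k_pe, ...)` pins the ordering: the query-side rope is not executed until `exec_kv` has produced `decode_k_pe`, so the two streams cannot race once the kernels are interleaved. A small check that this kind of overlap actually pays off, reusing `overlapped_step_sketch` from above (hypothetical shapes, CUDA stand-in, not a measurement of this PR):

```python
# Hypothetical micro-benchmark: per-step latency with the secondary stream
# enabled vs. disabled (CUDA stand-in; shapes are made up).
import torch


def time_overlap_sketch(enabled: bool, iters: int = 50) -> float:
    x = torch.randn(64, 4096, device="cuda")
    w_q = torch.randn(4096, 4096, device="cuda")
    w_kv = torch.randn(4096, 576, device="cuda")  # 512 + 64, DeepSeek-like split
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start.record()
    for _ in range(iters):
        overlapped_step_sketch(x, w_q, w_kv, enabled)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters  # milliseconds per step


# Expect the enabled run to be no slower, and faster when the two matmuls
# can genuinely overlap on the device.
print(time_overlap_sketch(False), time_overlap_sketch(True))
```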
|
|