 from vllm_ascend.multistream.context import get_multistream_comm_context
 from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
 from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla
+from vllm_ascend.utils import npu_stream_switch, npu_wait_tensor
 
 if TYPE_CHECKING:
     from vllm.v1.core.sched.output import SchedulerOutput
@@ -557,6 +558,9 @@ def __init__(
         ascend_config = get_ascend_config()
         self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
         self.enable_kv_nz = ascend_config.torchair_graph_config.enable_kv_nz
+        self.enable_multistream_mla = \
+            ascend_config.torchair_graph_config.enable_multistream_mla
+
         # Adapt torch air graph mode with spec decoding.
         speculative_config = get_current_vllm_config().speculative_config
         if speculative_config is not None:
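For reference, the new flag sits alongside the other torchair graph options read in `__init__` above. A minimal sketch of how a user might turn it on, assuming vllm-ascend's usual `additional_config` plumbing (the exact keys below are an assumption, not shown in this diff):

    from vllm import LLM

    # Hypothetical config example: the two keys mirror the attributes read
    # above (self.torchair_graph_enabled / self.enable_multistream_mla).
    llm = LLM(
        model="deepseek-ai/DeepSeek-V2-Lite",
        additional_config={
            "torchair_graph_config": {
                "enabled": True,
                "enable_multistream_mla": True,
            },
        },
    )

With the flag off, every call site below degrades to the previous single-stream behavior via the `enabled=` keyword.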
@@ -861,17 +865,20 @@ def exec_kv(
         # npu_kv_rmsnorm_rope_cache needs [B, N, S, D]
         kv = kv.view(B, N, S, self.kv_lora_rank + self.qk_rope_head_dim)
         cache_mode = "PA_NZ" if self.enable_kv_nz else "PA"
-        k_pe, k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache(
-            kv,
-            self.kv_a_layernorm.weight,
-            cos,
-            sin,
-            slots.to(torch.int64),
-            kv_cache[1],
-            kv_cache[0],
-            epsilon=self.kv_a_layernorm.variance_epsilon,
-            cache_mode=cache_mode,
-        )
+        with npu_stream_switch("mla_secondary",
+                               0,
+                               enabled=self.enable_multistream_mla):
+            k_pe, k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache(
+                kv,
+                self.kv_a_layernorm.weight,
+                cos,
+                sin,
+                slots.to(torch.int64),
+                kv_cache[1],
+                kv_cache[0],
+                epsilon=self.kv_a_layernorm.variance_epsilon,
+                cache_mode=cache_mode,
+            )
         return k_pe, k_nope
 
     def exec_kv_prefill(
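The `npu_stream_switch` context manager is what lets `npu_kv_rmsnorm_rope_cache` overlap with query-side matmuls on the default stream. A minimal sketch of its likely contract, assuming `torch.npu` mirrors `torch.cuda`'s stream API; the real helper lives in `vllm_ascend.utils` and this is illustrative only, not the actual implementation:

    from contextlib import nullcontext
    import torch

    _SECONDARY_STREAMS: dict = {}

    def npu_stream_switch_sketch(tag: str, priority: int, enabled: bool = True):
        # Degrade to a no-op when multistream MLA is disabled, so call
        # sites stay branch-free.
        if not enabled:
            return nullcontext()
        # Cache one secondary stream per tag so repeated layers reuse it.
        if tag not in _SECONDARY_STREAMS:
            _SECONDARY_STREAMS[tag] = torch.npu.Stream(priority=priority)
        return torch.npu.stream(_SECONDARY_STREAMS[tag])

Keying streams by tag (here `"mla_secondary"`) means every decoder layer enqueues its KV-side work onto the same secondary stream rather than allocating a stream per layer.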
@@ -1064,23 +1071,38 @@ def forward(
         if has_decode:
             decode_k_nope = None
             assert attn_metadata.decode is not None
-            decode_ql_nope, decode_q_pe = \
-                self._q_proj_and_k_up_proj(decode_hs_or_q_c)
             if self.running_in_graph:
                 seq_len = self.rotary_emb.max_position_embeddings
                 cos = self.rotary_emb.cos_cached[:seq_len].to(
-                    dtype=decode_q_pe.dtype)
+                    dtype=decode_hs_or_q_c.dtype)
                 sin = self.rotary_emb.sin_cached[:seq_len].to(
-                    dtype=decode_q_pe.dtype)
+                    dtype=decode_hs_or_q_c.dtype)
                 cos = cos[attn_metadata.decode.input_positions]
                 sin = sin[attn_metadata.decode.input_positions]
                 cos = cos[:, None, None, :]
                 sin = sin[:, None, None, :]
-
-                decode_q_pe = self.rope_single(decode_q_pe, cos, sin)
+                # Without explicitly controlling the order, IndexByTensor
+                # operations would be placed after `matmul W_KV_T`, hindering
+                # the overlapping of KvRmsNormRopeCache and SingleRope.
+                npu_wait_tensor(decode_hs_or_q_c,
+                                cos,
+                                enabled=self.enable_multistream_mla)
+                npu_wait_tensor(decode_hs_or_q_c,
+                                sin,
+                                enabled=self.enable_multistream_mla)
+            decode_ql_nope, decode_q_pe = \
+                self._q_proj_and_k_up_proj(decode_hs_or_q_c)
+            if self.running_in_graph:
                 decode_k_pe, decode_k_nope = self.exec_kv(
                     hidden_states_or_kv_c_normed, cos, sin, kv_cache,
                     attn_metadata.slot_mapping)
+                with npu_stream_switch("mla_secondary",
+                                       0,
+                                       enabled=self.enable_multistream_mla):
+                    npu_wait_tensor(decode_q_pe,
+                                    decode_k_pe,
+                                    enabled=self.enable_multistream_mla)
+                    decode_q_pe = self.rope_single(decode_q_pe, cos, sin)
             else:
                 decode_q_pe[...], decode_k_pe[...] = self.rotary_emb(
                     attn_metadata.decode.input_positions,
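Taken together, the decode path now forms a two-stream pipeline: `exec_kv` (KvRmsNormRopeCache) is issued on `"mla_secondary"` while `_q_proj_and_k_up_proj` runs on the default stream, and `npu_wait_tensor(decode_q_pe, decode_k_pe, ...)` joins the two before the query rope. A self-contained toy of that pattern, assuming `torch.npu` exposes the same `Stream`/`stream`/`wait_stream` API as `torch.cuda`; the matmuls are stand-ins, not vllm-ascend code:

    import torch

    def overlapped_decode_sketch(q_c: torch.Tensor, kv: torch.Tensor,
                                 w_q: torch.Tensor, w_kv: torch.Tensor):
        secondary = torch.npu.Stream()
        # Make sure the secondary stream sees the finished inputs.
        secondary.wait_stream(torch.npu.current_stream())
        # KV-side work on the secondary stream (stands in for exec_kv).
        with torch.npu.stream(secondary):
            k = kv @ w_kv
        # Query-side projection overlaps on the default stream
        # (stands in for _q_proj_and_k_up_proj).
        q = q_c @ w_q
        # Join before the query rope step, analogous to npu_wait_tensor.
        torch.npu.current_stream().wait_stream(secondary)
        return q, k

Note that the diff expresses the join per tensor (`npu_wait_tensor`) rather than per stream, which gives the torchair graph compiler a finer-grained ordering edge than a whole-stream barrier.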