
Commit 9a5e650

[0.9.1][Perf] Port MLA multistream optimization and prefetch to v0.9.1 (#1750)
This PR ports the optimization in PR #1353 to v0.9.1-dev.

Signed-off-by: whx-sjtu <2952154980@qq.com>
1 parent a9da140 commit 9a5e650

3 files changed: +57 additions, -46 deletions

vllm_ascend/attention/mla_v1.py

Lines changed: 33 additions & 35 deletions
@@ -22,8 +22,8 @@
 from vllm_ascend.multistream.context import get_multistream_comm_context
 from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
 from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla
-from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, npu_stream_switch,
-                               npu_wait_tensor)
+from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, npu_prefetch,
+                               npu_stream_switch, npu_wait_tensor)
 
 if TYPE_CHECKING:
     from vllm.v1.core.sched.output import SchedulerOutput
@@ -627,22 +627,25 @@ def __init__(
         ascend_config = get_ascend_config()
         self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
         self.enable_kv_nz = ascend_config.torchair_graph_config.enable_kv_nz
-        self.enable_multistream_mla = \
-            ascend_config.torchair_graph_config.enable_multistream_mla
 
         # Adapt torch air graph mode with spec decoding.
         speculative_config = get_current_vllm_config().speculative_config
         if speculative_config is not None:
             self.spec_token_num = speculative_config.num_speculative_tokens
             assert self.spec_token_num > 0
 
-    def _v_up_proj_and_o_proj(self, x):
+    def _v_up_proj_and_o_proj(self, x, enable_multistream_mla: bool = False):
         # Convert from (B, N, L) to (N, B, L)
         x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
         # Multiply (N, B, L) x (N, L, V) -> (N, B, V)
         x = torch.bmm(x, self.W_UV)
         # Convert from (N, B, V) to (B, N * V)
         x = x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim)
+        MAX_O_PROJ_PREFETCH_SIZE = 16 * 1024 * 1024  # 16MB
+        npu_prefetch(self.o_proj.weight,
+                     x,
+                     max_size=MAX_O_PROJ_PREFETCH_SIZE,
+                     enabled=enable_multistream_mla)
         return self.o_proj(x)[0]
 
     # Return `ql_nope`, `q_pe`
@@ -933,20 +936,17 @@ def exec_kv(
         # npu_kv_rmsnorm_rope_cache needs [B, N, S, D]
         kv = kv.view(B, N, S, self.kv_lora_rank + self.qk_rope_head_dim)
         cache_mode = "PA_NZ" if self.enable_kv_nz else "PA"
-        with npu_stream_switch("mla_secondary",
-                               0,
-                               enabled=self.enable_multistream_mla):
-            k_pe, k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache(
-                kv,
-                self.kv_a_layernorm.weight,
-                cos,
-                sin,
-                slots.to(torch.int64),
-                kv_cache[1],
-                kv_cache[0],
-                epsilon=self.kv_a_layernorm.variance_epsilon,
-                cache_mode=cache_mode,
-            )
+        k_pe, k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache(
+            kv,
+            self.kv_a_layernorm.weight,
+            cos,
+            sin,
+            slots.to(torch.int64),
+            kv_cache[1],
+            kv_cache[0],
+            epsilon=self.kv_a_layernorm.variance_epsilon,
+            cache_mode=cache_mode,
+        )
         return k_pe, k_nope
 
     def exec_kv_prefill(
@@ -999,6 +999,7 @@ def _forward_decode(
         k_pe: torch.Tensor,
         kv_c_and_k_pe_cache: Tuple[torch.Tensor],
         attn_metadata: AscendMLAMetadata,
+        enable_multistream_mla: bool = False,
     ) -> torch.Tensor:
         decode_meta = attn_metadata.decode
         assert decode_meta is not None
@@ -1093,7 +1094,8 @@ def _forward_decode(
                 out=attn_output)
         current_ms_metadata = get_multistream_comm_context()
         if current_ms_metadata is None:
-            return self._v_up_proj_and_o_proj(attn_output)
+            return self._v_up_proj_and_o_proj(attn_output,
+                                              enable_multistream_mla)
         else:
             current_ms_metadata.before_comm_event.record()
             with torch.npu.stream(current_ms_metadata.comm_stream):
@@ -1109,6 +1111,7 @@ def forward(
         kv_cache: Tuple[torch.Tensor],
         attn_metadata: M,
         output: Optional[torch.Tensor] = None,
+        enable_multistream_mla=False,
     ) -> torch.Tensor:
         assert output is not None, "Output tensor must be provided."
         if attn_metadata is None:
@@ -1158,27 +1161,21 @@ def forward(
             if self.running_in_graph:
                 cos = attn_metadata.decode.cos
                 sin = attn_metadata.decode.sin
-                # Without explicitly controlling the order, IndexByTensor operations
-                # would be placed after `matmul W_KV_T` hindering the overlapping of
-                # KvRmsNormRopeCache and SingleRope.
-                npu_wait_tensor(decode_hs_or_q_c,
-                                cos,
-                                enabled=self.enable_multistream_mla)
-                npu_wait_tensor(decode_hs_or_q_c,
-                                sin,
-                                enabled=self.enable_multistream_mla)
+                with npu_stream_switch("mla_secondary",
+                                       0,
+                                       enabled=enable_multistream_mla):
+                    decode_k_pe, decode_k_nope = self.exec_kv(
+                        hidden_states_or_kv_c_normed, cos, sin, kv_cache,
+                        attn_metadata.slot_mapping)
             decode_ql_nope, decode_q_pe = \
                 self._q_proj_and_k_up_proj(decode_hs_or_q_c)
             if self.running_in_graph:
-                decode_k_pe, decode_k_nope = self.exec_kv(
-                    hidden_states_or_kv_c_normed, cos, sin, kv_cache,
-                    attn_metadata.slot_mapping)
                 with npu_stream_switch("mla_secondary",
                                        0,
-                                       enabled=self.enable_multistream_mla):
+                                       enabled=enable_multistream_mla):
                     npu_wait_tensor(decode_q_pe,
                                     decode_k_pe,
-                                    enabled=self.enable_multistream_mla)
+                                    enabled=enable_multistream_mla)
                     decode_q_pe = self.rope_single(decode_q_pe, cos, sin)
             else:
                 decode_q_pe[...], decode_k_pe[...] = self.rotary_emb(
@@ -1253,7 +1250,8 @@ def forward(
             if self.running_in_graph:
                 return self._forward_decode(decode_ql_nope, decode_q_pe,
                                             decode_k_nope, decode_k_pe,
-                                            kv_cache, attn_metadata)
+                                            kv_cache, attn_metadata,
+                                            enable_multistream_mla)
             else:
                 output_decode = self._forward_decode(decode_ql_nope,
                                                      decode_q_pe,

vllm_ascend/models/deepseek_v2.py

Lines changed: 11 additions & 11 deletions
@@ -68,8 +68,7 @@
 from vllm_ascend.ops.fused_moe import AscendFusedMoE
 from vllm_ascend.quantization.quant_config import AscendLinearMethod
 from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod
-from vllm_ascend.utils import (dispose_tensor, npu_stream_switch,
-                               npu_wait_tensor)
+from vllm_ascend.utils import dispose_tensor, npu_prefetch
 
 
 class CustomDeepseekV2SiluAndMul(SiluAndMul):
@@ -472,21 +471,22 @@ def forward(
             hidden_states: torch.Tensor,
             kv_cache: Optional[torch.Tensor] = None,
             attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
+        forward_context = get_forward_context()
+        enable_multistream_mla = (self.enable_multistream_mla
+                                  and attn_metadata is not None
+                                  and not forward_context.with_prefill
+                                  and attn_metadata.num_decodes > 0)
+        forward_kwargs = {"enable_multistream_mla": enable_multistream_mla}
         if self.q_lora_rank is not None:
+            npu_prefetch(self.q_a_proj.weight,
+                         hidden_states,
+                         enabled=enable_multistream_mla)
             ckq = self.q_a_proj(hidden_states)[0]
-            use_multistream_mla = (self.enable_multistream_mla
-                                   and attn_metadata is not None
-                                   and attn_metadata.num_decodes > 0)
-            npu_wait_tensor(hidden_states, ckq, enabled=use_multistream_mla)
-            with npu_stream_switch("mla_secondary",
-                                   0,
-                                   enabled=use_multistream_mla):
-                hidden_states_or_q_c = self.q_a_layernorm(ckq)
+            hidden_states_or_q_c = self.q_a_layernorm(ckq)
         else:
             hidden_states_or_q_c = hidden_states
         is_mtp_model = attn_metadata is not None and attn_metadata.is_mtp_model
         if self.torchair_graph_enabled and not is_mtp_model:
-            forward_kwargs = {}
             if envs.VLLM_USE_V1:
                 output_shape = hidden_states.shape
                 output = torch.empty(output_shape,

vllm_ascend/utils.py

Lines changed: 13 additions & 0 deletions
@@ -303,6 +303,19 @@ def npu_wait_tensor(self: torch.Tensor,
     return _npu_wait_tensor(self, dependency) if enabled else self
 
 
+def npu_prefetch(input: torch.Tensor,
+                 dependency: torch.Tensor,
+                 max_size: int = 0,
+                 *,
+                 enabled: bool = True):
+    if not enabled:
+        return
+    input_size = input.element_size() * input.numel()
+    if max_size <= 0 or max_size > input_size:
+        max_size = input_size
+    torch_npu.npu_prefetch(input, dependency, max_size)
+
+
 class AscendSocVersion(Enum):
     A2 = 0
     A3 = 1
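
For reference, the new npu_prefetch helper in vllm_ascend/utils.py clamps the requested prefetch size to the actual byte size of the weight and becomes a no-op when enabled=False, so call sites can gate it directly on the multistream-MLA flag. Below is a minimal usage sketch, not code from this commit: the weight and activation tensors are hypothetical placeholders, whereas the real call sites in this diff pass self.o_proj.weight or self.q_a_proj.weight together with the current activations.

import torch
import torch_npu  # noqa: F401  # registers the "npu" device

from vllm_ascend.utils import npu_prefetch

# Hypothetical stand-ins for a projection weight and the activation that
# will consume it on the NPU.
weight = torch.empty(2048, 7168, dtype=torch.bfloat16, device="npu")
activation = torch.empty(16, 2048, dtype=torch.bfloat16, device="npu")

MAX_PREFETCH_SIZE = 16 * 1024 * 1024  # same 16MB cap used in _v_up_proj_and_o_proj

# Prefetch the weight, keyed on the activation tensor; with enabled=False the
# helper returns immediately, so the gate can simply be the multistream-MLA flag.
npu_prefetch(weight, activation, max_size=MAX_PREFETCH_SIZE, enabled=True)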

0 commit comments