
Commit bbb5981

chenwanerwangxiaoxin authored and committed
Enable kvcache_nz for the decode process in torchair graph mode (#1098)
What this PR does / why we need it?

Enable kvcache_nz for the decode process in torchair graph mode, which reduces the time consumed by FA in long sequences.

Does this PR introduce any user-facing change?

To enable kvcache_nz, set `additional_config.torchair_graph_config.enable_kv_nz=True`.

How was this patch tested?

1. Tested on a DeepSeek model: with batch size 64 and seq_len 1k+3k, total FA time across the 61 layers improves from 20.80 ms to 19.76 ms.
2. Operator precision test: aclnnFusedInferAttentionScoreV3_result.csv (https://github.com/user-attachments/files/20664138/aclnnFusedInferAttentionScoreV3_result.csv)
3. TPOT test from @ttanzhiqiang, and a single curl request returns a normal result: #1098 (comment), #1098 (comment)

---------

Signed-off-by: chenwaner <861645847@qq.com>
1 parent c03e12e commit bbb5981
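As a quick illustration of the user-facing change, and assuming the offline `additional_config` engine argument described in `docs/source/user_guide/additional_config.md` below, enabling the new option could look like the sketch here (the model name and everything outside `torchair_graph_config` are placeholders):

```python
from vllm import LLM

# Hedged sketch: only the torchair_graph_config keys come from this PR's doc
# change; the model name and any other settings are placeholders.
llm = LLM(
    model="deepseek-ai/DeepSeek-V3",  # placeholder
    additional_config={
        "torchair_graph_config": {
            "enabled": True,        # torchair graph mode must be on
            "enable_kv_nz": True,   # new flag added by this PR
        },
    },
)
```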

File tree

3 files changed: +96 -47 lines changed

docs/source/user_guide/additional_config.md

Lines changed: 3 additions & 1 deletion
@@ -45,6 +45,7 @@ The details of each config option are as follows:
 | `use_cached_graph` | bool | `False` | Whether to use cached graph |
 | `graph_batch_sizes` | list[int] | `[]` | The batch size for torchair graph cache |
 | `graph_batch_sizes_init` | bool | `False` | Init graph batch size dynamically if `graph_batch_sizes` is empty |
+| `enable_kv_nz`| bool | `False` | Whether to enable kvcache NZ layout |

 **ascend_scheduler_config**

@@ -65,7 +66,8 @@ A full example of additional configuration is as follows:
         "use_cached_graph": true,
         "graph_batch_sizes": [1, 2, 4, 8],
         "graph_batch_sizes_init": false,
-        "enable_multistream_moe": false
+        "enable_multistream_moe": false,
+        "enable_kv_nz": false
     },
     "ascend_scheduler_config": {
         "enabled": true,

vllm_ascend/ascend_config.py

Lines changed: 1 addition & 0 deletions
@@ -60,6 +60,7 @@ def __init__(self, torchair_graph_config):
             "enable_multistream_moe", False)
         self.enable_view_optimize = torchair_graph_config.get(
             "enable_view_optimize", True)
+        self.enable_kv_nz = torchair_graph_config.get("enable_kv_nz", False)

         if not isinstance(self.graph_batch_sizes, list):
             raise TypeError("graph_batch_sizes must be list[int]")
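To make the plumbing explicit, here is a condensed, hypothetical sketch (standalone names, not the real classes) of the getter pattern above combined with the cache-mode selection added in `mla_v1.py` below: a missing key falls back to `False`, and only an explicit `True` switches the kernels to the NZ cache modes.

```python
from typing import Any, Dict


class TorchairGraphConfigSketch:
    """Hypothetical stand-in for the config object built in ascend_config.py."""

    def __init__(self, torchair_graph_config: Dict[str, Any]):
        # Same pattern as the diff above: absent key -> default False.
        self.enabled = torchair_graph_config.get("enabled", False)
        self.enable_kv_nz = torchair_graph_config.get("enable_kv_nz", False)


def decode_cache_mode(cfg: TorchairGraphConfigSketch) -> str:
    # Mirrors `cache_mode = "PA_NZ" if self.enable_kv_nz else "PA"` in exec_kv.
    return "PA_NZ" if cfg.enable_kv_nz else "PA"


assert decode_cache_mode(TorchairGraphConfigSketch({"enable_kv_nz": True})) == "PA_NZ"
assert decode_cache_mode(TorchairGraphConfigSketch({})) == "PA"
```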

vllm_ascend/attention/mla_v1.py

Lines changed: 92 additions & 46 deletions
@@ -556,6 +556,7 @@ def __init__(

         ascend_config = get_ascend_config()
         self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
+        self.enable_kv_nz = ascend_config.torchair_graph_config.enable_kv_nz
         # Adapt torch air graph mode with spec decoding.
         speculative_config = get_current_vllm_config().speculative_config
         if speculative_config is not None:
@@ -859,6 +860,7 @@ def exec_kv(
         kv = self.kv_a_proj_with_mqa(hidden_states)[0]
         # npu_kv_rmsnorm_rope_cache needs [B, N, S, D]
         kv = kv.view(B, N, S, self.kv_lora_rank + self.qk_rope_head_dim)
+        cache_mode = "PA_NZ" if self.enable_kv_nz else "PA"
         k_pe, k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache(
             kv,
             self.kv_a_layernorm.weight,
@@ -868,7 +870,37 @@ def exec_kv(
             kv_cache[1],
             kv_cache[0],
             epsilon=self.kv_a_layernorm.variance_epsilon,
-            cache_mode="PA",
+            cache_mode=cache_mode,
+        )
+        return k_pe, k_nope
+
+    def exec_kv_prefill(
+        self,
+        hidden_states: torch.Tensor,
+        cos: torch.Tensor,
+        sin: torch.Tensor,
+        kv_cache: Tuple,
+        slots: torch.Tensor,
+    ):
+
+        B = hidden_states.shape[0]
+        N = self.num_kv_heads
+        S = 1
+        kv = self.kv_a_proj_with_mqa(hidden_states)[0]
+        # npu_kv_rmsnorm_rope_cache needs [B, N, S, D]
+        kv = kv.view(B, N, S, self.kv_lora_rank + self.qk_rope_head_dim)
+        cache_mode = "PA_BLK_NZ" if self.enable_kv_nz else "PA"
+        _, _, k_pe, k_nope = torch_npu.npu_kv_rmsnorm_rope_cache(
+            kv,
+            self.kv_a_layernorm.weight,
+            cos,
+            sin,
+            slots.to(torch.int64),
+            kv_cache[1],
+            kv_cache[0],
+            epsilon=self.kv_a_layernorm.variance_epsilon,
+            cache_mode=cache_mode,
+            is_output_kv=True,
         )
         return k_pe, k_nope

@@ -906,42 +938,50 @@ def _forward_decode(
         # TorchAir's shape is [bs, num_heads_per_rank, q_seq_len, dim]
         if attn_metadata.attn_state == AscendAttentionState.SpecDecoding:
             assert num_tokens % self.spec_token_num == 0
-            q_nope = (q_nope.view(
-                num_tokens // (self.spec_token_num + 1),
-                self.spec_token_num + 1,
-                self.num_heads,
-                -1,
-            ).transpose(1, 2).contiguous())
-            q_pe = (q_pe.view(
-                num_tokens // (self.spec_token_num + 1),
-                self.spec_token_num + 1,
-                self.num_heads,
-                -1,
-            ).transpose(1, 2).contiguous())
+            q_nope = q_nope.view(num_tokens // (self.spec_token_num + 1),
+                                 self.spec_token_num + 1, self.num_heads,
+                                 -1)
+            q_pe = q_pe.view(num_tokens // (self.spec_token_num + 1),
+                             self.spec_token_num + 1, self.num_heads, -1)
+            if not self.enable_kv_nz:
+                q_nope = q_nope.transpose(1, 2).contiguous()
+                q_pe = q_pe.transpose(1, 2).contiguous()
             sparse_mode = 3
             spec_attn_mask = attn_metadata.decode.attn_mask  # type:ignore
         else:
-            q_nope = q_nope.view(num_tokens, self.num_heads, 1, -1)
-            q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1)
+            if self.enable_kv_nz:
+                q_nope = q_nope.view(num_tokens, 1, self.num_heads, -1)
+                q_pe = q_pe.view(num_tokens, 1, self.num_heads, -1)
+            else:
+                q_nope = q_nope.view(num_tokens, self.num_heads, 1, -1)
+                q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1)
             sparse_mode = 0
             spec_attn_mask = None
         # shape of knope/k_pe for npu graph mode should be:
         # [num_blocks, num_kv_heads, block_size, self.kv_lora_rank/self.qk_rope_head_dim]
         block_size = kv_c_and_k_pe_cache[0].shape[1]
-        k_nope = k_nope.view(-1, self.num_kv_heads, block_size,
-                             self.kv_lora_rank)
-        k_pe = k_pe.view(-1, self.num_kv_heads, block_size,
-                         self.qk_rope_head_dim)
+        if self.enable_kv_nz:
+            k_nope = k_nope.view(-1, self.num_kv_heads,
+                                 self.kv_lora_rank // 16, block_size, 16)
+            k_pe = k_pe.view(-1, self.num_kv_heads,
+                             self.qk_rope_head_dim // 16, block_size, 16)
+            input_layout = "BSND"
+        else:
+            k_nope = k_nope.view(-1, self.num_kv_heads, block_size,
+                                 self.kv_lora_rank)
+            k_pe = k_pe.view(-1, self.num_kv_heads, block_size,
+                             self.qk_rope_head_dim)
+            input_layout = "BNSD"

-        attn_output, _ = torch.ops.npu.npu_fused_infer_attention_score(
+        attn_output, _ = torch_npu.npu_fused_infer_attention_score(
             q_nope,
             k_nope,
             k_nope,
             query_rope=q_pe,
             key_rope=k_pe,
             num_heads=self.num_heads,
             num_key_value_heads=self.num_kv_heads,
-            input_layout="BNSD",
+            input_layout=input_layout,
             atten_mask=spec_attn_mask,
             sparse_mode=sparse_mode,
             scale=self.scale,
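The NZ branch above only changes how existing cache blocks are viewed: the trailing head dimension is regrouped into 16-element fragments placed ahead of the token-slot dimension, and `input_layout` switches from `"BNSD"` to `"BSND"`. The snippet below is a host-side sketch of that shape bookkeeping only; the dimension names follow the diff, while the actual NZ element ordering is produced on-device by `npu_kv_rmsnorm_rope_cache` with the `PA_NZ`/`PA_BLK_NZ` cache modes, so the `permute` here is purely illustrative.

```python
import torch

# Illustrative shapes only; values are arbitrary.
num_blocks, num_kv_heads, block_size, kv_lora_rank = 4, 1, 128, 512

# Standard ("ND") paged-KV block: one kv_lora_rank-wide row per token slot.
nd = torch.randn(num_blocks, num_kv_heads, block_size, kv_lora_rank)

# NZ-style regrouping: split the channel dim into 16-wide fragments and move
# the fragment index ahead of the token-slot dim, matching
# k_nope.view(-1, num_kv_heads, kv_lora_rank // 16, block_size, 16) above.
nz_like = (nd.view(num_blocks, num_kv_heads, block_size, kv_lora_rank // 16, 16)
             .permute(0, 1, 3, 2, 4)
             .contiguous())
assert nz_like.shape == (num_blocks, num_kv_heads, kv_lora_rank // 16, block_size, 16)
```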
@@ -990,10 +1030,11 @@ def forward(
         ]
         num_actual_toks = attn_metadata.num_actual_tokens
         if k_pe is None and not self.running_in_graph:
-            kv_c, k_pe = self.kv_a_proj_with_mqa(
-                hidden_states_or_kv_c_normed)[0].split(
-                    [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
-            kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
+            if not self.torchair_graph_enabled:
+                kv_c, k_pe = self.kv_a_proj_with_mqa(
+                    hidden_states_or_kv_c_normed)[0].split(
+                        [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
+                kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
         else:
             kv_c_normed = hidden_states_or_kv_c_normed
         assert attn_metadata.num_decodes is not None and \
@@ -1006,16 +1047,18 @@ def forward(
         # Inputs and outputs may be padded for CUDA graphs
         output_padded = output
         output = output[:num_actual_toks, ...]
-        kv_c_normed = kv_c_normed[:num_actual_toks, ...]
-        prefill_k_c_normed = kv_c_normed[num_decode_tokens:]
+        if not self.torchair_graph_enabled:
+            kv_c_normed = kv_c_normed[:num_actual_toks, ...]
+            prefill_k_c_normed = kv_c_normed[num_decode_tokens:]
         if not self.running_in_graph:
             hidden_states_or_q_c = hidden_states_or_q_c[:num_actual_toks, ...]
-            decode_hs_or_q_c = hidden_states_or_q_c[:num_decode_tokens]
             prefill_hs_or_q_c = hidden_states_or_q_c[num_decode_tokens:]
-            k_pe = k_pe[:num_actual_toks, ...]
-            k_pe = k_pe.unsqueeze(1)
-            decode_k_pe = k_pe[:num_decode_tokens]
-            prefill_k_pe = k_pe[num_decode_tokens:]
+            if not self.torchair_graph_enabled:
+                decode_hs_or_q_c = hidden_states_or_q_c[:num_decode_tokens]
+                k_pe = k_pe[:num_actual_toks, ...]
+                k_pe = k_pe.unsqueeze(1)
+                decode_k_pe = k_pe[:num_decode_tokens]
+                prefill_k_pe = k_pe[num_decode_tokens:]
         else:
             decode_hs_or_q_c = hidden_states_or_q_c
         if has_decode:
@@ -1052,22 +1095,25 @@ def forward(
             prefill_q_nope = prefill_q[..., :self.qk_nope_head_dim]
             if self.torchair_graph_enabled:
                 num_tokens = prefill_hs_or_q_c.shape[0]
+                seq_len = self.rotary_emb.max_position_embeddings
+                cos = self.rotary_emb.cos_cached[:seq_len].to(
+                    dtype=prefill_q_pe.dtype)
+                sin = self.rotary_emb.sin_cached[:seq_len].to(
+                    dtype=prefill_q_pe.dtype)
+                cos = cos[attn_metadata.prefill.input_positions]
+                sin = sin[attn_metadata.prefill.input_positions]
+                cos = cos[:, None, None, :]
+                sin = sin[:, None, None, :]
+
+                prefill_q_pe = self.rope_single(prefill_q_pe, cos, sin)
+                prefill_k_pe, prefill_k_nope = self.exec_kv_prefill(
+                    hidden_states_or_kv_c_normed, cos, sin, kv_cache,
+                    attn_metadata.slot_mapping)
+
+                kv_c_normed = prefill_k_nope[:num_actual_toks, ...]
+                prefill_k_c_normed = prefill_k_nope[num_decode_tokens:]
                 prefill_k_pe = prefill_k_pe.view(num_tokens, self.num_kv_heads,
                                                  -1)
-                if self.rotary_emb.__class__.__name__ == 'RotaryEmbedding':
-                    # NOTE: When scaling not specified
-                    ori_q_pe_shape, ori_k_pe_shape = prefill_q_pe.shape, prefill_k_pe.shape
-                    prefill_q_pe = prefill_q_pe.reshape(num_tokens, -1)
-                    prefill_k_pe = prefill_k_pe.reshape(num_tokens, -1)
-                    prefill_q_pe, prefill_k_pe = self.rotary_emb(
-                        attn_metadata.prefill.input_positions, prefill_q_pe,
-                        prefill_k_pe)
-                    prefill_q_pe = prefill_q_pe.view(ori_q_pe_shape)
-                    prefill_k_pe = prefill_k_pe.view(ori_k_pe_shape)
-                else:
-                    prefill_q_pe, prefill_k_pe = self.rotary_emb(
-                        attn_metadata.prefill.input_positions, prefill_q_pe,
-                        prefill_k_pe)
                 prefill_q = torch.cat([prefill_q_nope, prefill_q_pe], dim=-1)
             else:
                 prefill_q_pe[...], prefill_k_pe[...] = self.rotary_emb(
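In the new torchair prefill branch above, `cos`/`sin` are gathered per input position from the rotary-embedding cache and broadcast to `[num_tokens, 1, 1, rot_dim]` before `rope_single` and `exec_kv_prefill` consume them. Below is a generic, hypothetical sketch of that gather-and-apply step using the standard neox-style rotation; it is not the actual `rope_single` implementation, and every name outside the diff is made up.

```python
import torch


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Standard neox-style half rotation.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rope(q_pe, cos_cached, sin_cached, positions):
    # Gather per-token cos/sin, then broadcast over head/seq dims, mirroring
    # cos[attn_metadata.prefill.input_positions][:, None, None, :] above.
    cos = cos_cached[positions][:, None, None, :]  # [num_tokens, 1, 1, rot_dim]
    sin = sin_cached[positions][:, None, None, :]
    return q_pe * cos + rotate_half(q_pe) * sin


# Toy rotary cache: 4096 positions, rotary dim 64 (qk_rope_head_dim in the diff).
max_pos, rot_dim = 4096, 64
inv_freq = 1.0 / (10000 ** (torch.arange(0, rot_dim, 2).float() / rot_dim))
freqs = torch.outer(torch.arange(max_pos).float(), inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
cos_cached, sin_cached = emb.cos(), emb.sin()

q_pe = torch.randn(8, 16, 1, rot_dim)  # [num_tokens, num_heads, 1, rot_dim]
positions = torch.arange(8)            # prefill input positions
q_rot = apply_rope(q_pe, cos_cached, sin_cached, positions)
assert q_rot.shape == q_pe.shape
```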
