
Commit a970b27

Angazenn and angazenn authored
[WIP][Perf] remove unnecessary padding before MLA V1 prefill (#917)
### What this PR does / why we need it?
Currently, the MLA V1 implementation pads q, k, and v to a `head_dim` of 256 to conform to the early MLA kernel. The new MLA kernel supports `head_dim` values that are not divisible by 128, so we can remove this unnecessary padding to boost performance.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?

Signed-off-by: angazenn <zengyanjia@huawei.com>
Co-authored-by: angazenn <zengyanjia@huawei.com>
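
For context, the deleted `padding_head_dim` computation (visible in the diff below) rounds the combined q/k head dimension up to the next multiple of 128. A minimal sketch of that arithmetic, assuming DeepSeek-style head dimensions (128 "nope" + 64 rope; these concrete values are an illustrative assumption, not taken from this commit), shows where the 256 mentioned above comes from:

```python
# Sketch of the round-up performed by the removed self.padding_head_dim code.
# The head-dimension values below are assumed DeepSeek-style numbers, used
# only to illustrate why the padded size works out to 256.
qk_nope_head_dim = 128  # assumed value
qk_rope_head_dim = 64   # assumed value

qk_head_dim = qk_nope_head_dim + qk_rope_head_dim        # 192
padding_head_dim = ((qk_head_dim - 1) // 128 + 1) * 128  # ceil to a multiple of 128

print(qk_head_dim, "->", padding_head_dim)  # 192 -> 256
```
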
1 parent dc6172e commit a970b27

File tree

1 file changed (+5, -25 lines)

vllm_ascend/attention/mla_v1.py

Lines changed: 5 additions & 25 deletions

```diff
@@ -373,12 +373,6 @@ def __init__(
         self.qk_rope_head_dim = qk_rope_head_dim
         self.qk_head_dim = qk_head_dim
         self.v_head_dim = v_head_dim
-        # TODO: below padding should be removed after kernel is ready
-        # we found npu_flash_attention can only works on 128 divisible head_dim, we pad it to target size here
-        # and slice the final result to guarantee its functionality.
-        self.padding_head_dim = (
-            (self.qk_nope_head_dim + self.qk_rope_head_dim - 1) // 128 +
-            1) * 128
 
         # Hack for V1 for now to avoid torch library overhead (since we are
         # already inside an attention custom op), pull out the forward
@@ -520,7 +514,7 @@ def _forward_prefill(
         elif attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
             attn_output = torch.empty(num_tokens,
                                       self.num_heads,
-                                      self.padding_head_dim,
+                                      self.v_head_dim,
                                       dtype=query.dtype,
                                       device=query.device)
             k_nope, value = self.kv_b_proj(kv_c_normed)[0].view(
@@ -529,31 +523,17 @@ def _forward_prefill(
                 [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
             key = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))),
                             dim=-1)
-            pad_query = torch.nn.functional.pad(query, [
-                0, self.padding_head_dim - self.qk_rope_head_dim -
-                self.qk_nope_head_dim
-            ],
-                                                value=0)
-            pad_key = torch.nn.functional.pad(key, [
-                0, self.padding_head_dim - self.qk_rope_head_dim -
-                self.qk_nope_head_dim
-            ],
-                                              value=0)
-            pad_value = torch.nn.functional.pad(
-                value, [0, self.padding_head_dim - self.v_head_dim], value=0)
             torch_npu._npu_flash_attention(
-                query=pad_query,
-                key=pad_key,
-                value=pad_value,
+                query=query,
+                key=key,
+                value=value,
                 mask=attn_metadata.attn_mask,
                 seq_len=attn_metadata.prefill.context_lens,
                 scale_value=self.scale,
                 num_heads=self.num_heads,
                 num_kv_heads=self.num_heads,
                 out=attn_output)
-            attn_output = attn_output.view(
-                -1, self.num_heads,
-                self.padding_head_dim)[:, :, :self.v_head_dim]
+            attn_output = attn_output.view(-1, self.num_heads, self.v_head_dim)
         else:
             raise RuntimeError(
                 "Unexpected path reached, AscendMLAImpl should only have PrefillNoCache and ChunkedPrefill scenario in forward prefill, please file a bug to vllm-ascend !"
```
