
Commit 0329fad

[Perf] Deepseekv3 performance optimization for eager mode (#598)
### What this PR does / why we need it?

DeepSeek V3 currently uses vanilla chunked prefill for the MLA part, which is inefficient to compute but necessary for chunked prefill. Since PR #543 brought the v0 scheduler into vllm-ascend, we can now use torch_npu._npu_flash_attention inside the MLA backend for a further performance boost. Some redundant computation inside the rope has also been removed. This PR should bring a performance gain for DeepSeek eager-mode inference.

---------

Signed-off-by: ganyi <pleaplusone.gy@gmail.com>
1 parent 87975fa commit 0329fad
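At a high level, the change makes the MLA prefill path branch on the batch's attention state: pure prefill batches go to the fused torch_npu._npu_flash_attention kernel, while chunked-prefill batches keep the vanilla fallback. A minimal sketch of that dispatch (the enum member and kernel names follow the diff below; the helper name mla_prefill_kernel and the stub bodies are invented for illustration):

```python
from enum import Enum, auto


class AscendAttentionState(Enum):
    # Stub of the two states handled by the new _forward_prefill branch.
    ChunkedPrefill = auto()
    PrefillOnly = auto()


def mla_prefill_kernel(attn_state: AscendAttentionState) -> str:
    # Skeleton of the dispatch added in AscendMLAImpl._forward_prefill.
    if attn_state == AscendAttentionState.ChunkedPrefill:
        # Chunked prefill keeps the vanilla (unfused) MLA path.
        return "vanilla_chunked_prefill_mla"
    if attn_state == AscendAttentionState.PrefillOnly:
        # Pure prefill batches use the fused kernel on 128-padded head dims.
        return "torch_npu._npu_flash_attention"
    raise RuntimeError("unexpected attention state for MLA prefill")


assert mla_prefill_kernel(AscendAttentionState.PrefillOnly) == \
    "torch_npu._npu_flash_attention"
```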

File tree

4 files changed: +178 -100 lines changed


tests/ops/test_rotary_embedding.py

Lines changed: 0 additions & 5 deletions
```diff
@@ -136,11 +136,6 @@ def forward_native(


 # test with leading dimension and merge seqlen and batch_size as num_tokens
-# TODO(ganyi): open this test in the future
-@pytest.mark.skip(
-    reason=
-    "skip this test by default for now because of ci issue, will enable it in the future"
-)
 @pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
 @pytest.mark.parametrize("batch_size", BATCH_SIZES)
 @pytest.mark.parametrize("seq_len", SEQ_LENS)
```

vllm_ascend/attention/mla_v1.py

Lines changed: 90 additions & 44 deletions
```diff
@@ -55,7 +55,7 @@ class AscendMLAPrefillMetadata:
     input_positions: torch.Tensor
     block_table: torch.Tensor
     max_query_len: int
-    max_context_len: int
+    max_seq_lens: int


 @dataclass
```
```diff
@@ -65,6 +65,7 @@ class AscendMLADecodeMetadata:
     input_positions: torch.Tensor
     block_table: torch.Tensor
     seq_lens: torch.Tensor
+    max_seq_lens: int


 @dataclass
```
```diff
@@ -131,11 +132,6 @@ def __init__(self,
         self.runner = runner
         scheduler_config = runner.scheduler_config
         self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled
-        # self.attn_mask = None
-        # if AscendMLAMetadataBuilder._attn_mask_builder is None:
-        #     AscendMLAMetadataBuilder._attn_mask_builder = AttentionMaskBuilder.initialize_from_len(
-        #         128, self.runner.model_config.dtype
-        #     )

     def reorder_batch(self, input_batch: "InputBatch",
                       scheduler_output: "SchedulerOutput") -> bool:
```
```diff
@@ -222,12 +218,14 @@ def build(self,
                                                              num_reqs]
         seq_lens = seq_lens_cpu
         max_query_len = query_lens.max().item()
-        max_context_len = seq_lens.max().item()
+        max_seq_lens = seq_lens.max().item()

         prefill_metadata = None
         if self._num_prefills > 0:
             reqs_start = self._num_decodes  # prefill_start
             tokens_start = self._num_decode_tokens
+            max_query_len = query_lens[tokens_start:].max().item()
+            max_seq_lens = seq_lens[tokens_start:].max().item()

             prefill_metadata = AscendMLAPrefillMetadata(
                 attn_mask=self.runner.attn_mask,
```
```diff
@@ -236,15 +234,17 @@ def build(self,
                 input_positions=input_positions[tokens_start:],
                 block_table=block_table[reqs_start:, ...],
                 max_query_len=max_query_len,
-                max_context_len=max_context_len,
+                max_seq_lens=max_seq_lens,
             )

         decode_metadata = None
         if self._num_decodes > 0:
+            max_seq_lens = seq_lens[:self._num_decodes].max().item()
             decode_metadata = AscendMLADecodeMetadata(
                 input_positions=input_positions[:self._num_decode_tokens],
                 block_table=block_table[:self._num_decode_tokens, ...],
-                seq_lens=seq_lens[:self._num_decode_tokens])
+                seq_lens=seq_lens[:self._num_decode_tokens],
+                max_seq_lens=max_seq_lens)

         return self.metadata_cls(  # type: ignore
             num_actual_tokens=num_actual_tokens,
```
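For context, the builder now records a max_seq_lens per partition instead of one global max_context_len. A small standalone sketch of the same slicing, assuming one token per decode request (so request and token offsets coincide, as in the decode slice above); the tensor values are invented for illustration:

```python
import torch

# Invented example batch: 2 decode requests followed by 2 prefill requests
# (the builder reorders decodes to the front), one token per decode request.
seq_lens = torch.tensor([17, 33, 1024, 2048])
query_lens = torch.tensor([1, 1, 512, 768])
num_decodes = 2

# Global fallback, as computed before the per-partition overrides.
max_seq_lens = seq_lens.max().item()                            # 2048

# Prefill partition (everything after the decode requests).
prefill_max_query_len = query_lens[num_decodes:].max().item()   # 768
prefill_max_seq_lens = seq_lens[num_decodes:].max().item()      # 2048

# Decode partition.
decode_max_seq_lens = seq_lens[:num_decodes].max().item()       # 33
print(prefill_max_query_len, prefill_max_seq_lens, decode_max_seq_lens)
```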
```diff
@@ -306,12 +306,18 @@ def __init__(
         self.qk_rope_head_dim = qk_rope_head_dim
         self.qk_head_dim = qk_head_dim
         self.v_head_dim = v_head_dim
+        # TODO: below padding should be removed after kernel is ready
+        # we found npu_flash_attention can only works on 128 divisible head_dim, we pad it to target size here
+        # and slice the final result to guarantee its functionality.
+        self.padding_head_dim = (
+            (self.qk_nope_head_dim + self.qk_rope_head_dim - 1) // 128 +
+            1) * 128

         # Hack for V1 for now to avoid torch library overhead (since we are
         # already inside an attention custom op), pull out the forward
         # method from the rotary embedding and call it directly
         # TODO(lucas): we should probably find a cleaner way to do this
-        self.rotary_emb = rotary_emb.forward_native
+        self.rotary_emb = rotary_emb

         self.q_proj = q_proj
         self.kv_b_proj = kv_b_proj
```
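The padding_head_dim added here is just the combined nope+rope head dim rounded up to the next multiple of 128, matching the comment about the kernel's 128-divisible constraint. A quick worked check, assuming DeepSeek-V3's commonly cited MLA dims (qk_nope_head_dim=128, qk_rope_head_dim=64; those values are not stated in this diff):

```python
def pad_to_128(qk_nope_head_dim: int, qk_rope_head_dim: int) -> int:
    # Same rounding as self.padding_head_dim above: the next multiple of 128
    # that is >= qk_nope_head_dim + qk_rope_head_dim.
    return ((qk_nope_head_dim + qk_rope_head_dim - 1) // 128 + 1) * 128

# Assumed DeepSeek-V3 MLA dims: 128 (nope) + 64 (rope) = 192 -> padded to 256.
assert pad_to_128(128, 64) == 256
# A total that is already 128-aligned is left unchanged: 64 + 64 = 128 -> 128.
assert pad_to_128(64, 64) == 128
```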
```diff
@@ -409,37 +415,73 @@ def _forward_prefill(
     ) -> torch.Tensor:
         assert attn_metadata.prefill is not None

-        # TODO: enable this compute for flash attention computation
-        # kv_nope = self.kv_b_proj(kv_c_normed)[0].view(\
-        #     -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
-        # k_nope, v = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
-        # key = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)
-        # v_padded = torch.nn.functional.pad(v, [0, query.shape[-1] - v.shape[-1]],
-        #                                    value=0)
         num_tokens = query.size(0)
-        attn_output = torch.empty(num_tokens,
-                                  self.num_heads,
-                                  self.v_head_dim,
-                                  dtype=query.dtype,
-                                  device=query.device)
-        # current requests is chunked in prefill, disable flash attention with chunked prefill
-        vanilla_chunked_prefill_mla(
-            output=attn_output,
-            query=query,
-            kv_cache=kv_c_and_k_pe_cache,
-            block_tables=attn_metadata.prefill.block_table,
-            query_lens=attn_metadata.prefill.query_lens,
-            context_lens=attn_metadata.prefill.context_lens,
-            kv_b_proj=self.kv_b_proj,
-            max_query_len=attn_metadata.prefill.max_query_len,
-            max_context_len=attn_metadata.prefill.max_context_len,
-            nope_dim=self.qk_nope_head_dim,
-            rope_dim=self.qk_rope_head_dim,
-            v_head_dim=self.v_head_dim,
-            scale=self.scale,
-            alibi_slopes=None,
-            causal=True)
-        attn_output = attn_output.view(
+        attn_output = None
+        # Here is only 2 possibility of input, ChunkedPrefill or PrefillOnly
+        if attn_metadata.attn_state == AscendAttentionState.ChunkedPrefill:
+            attn_output = torch.empty(num_tokens,
+                                      self.num_heads * self.v_head_dim,
+                                      dtype=query.dtype,
+                                      device=query.device)
+            # current requests is chunked in prefill, disable flash attention with chunked prefill
+            vanilla_chunked_prefill_mla(
+                output=attn_output,
+                query=query,
+                kv_cache=kv_c_and_k_pe_cache,
+                block_tables=attn_metadata.prefill.block_table,
+                query_lens=attn_metadata.prefill.query_lens,
+                context_lens=attn_metadata.prefill.context_lens,
+                kv_b_proj=self.kv_b_proj,
+                max_query_len=attn_metadata.prefill.max_query_len,
+                max_context_len=attn_metadata.prefill.max_seq_lens,
+                nope_dim=self.qk_nope_head_dim,
+                rope_dim=self.qk_rope_head_dim,
+                v_head_dim=self.v_head_dim,
+                scale=self.scale,
+                alibi_slopes=None,
+                causal=True)
+        elif attn_metadata.attn_state == AscendAttentionState.PrefillOnly:
+            attn_output = torch.empty(num_tokens,
+                                      self.num_heads,
+                                      self.padding_head_dim,
+                                      dtype=query.dtype,
+                                      device=query.device)
+            k_nope, value = self.kv_b_proj(kv_c_normed)[0].view(
+                -1, self.num_heads,
+                self.qk_nope_head_dim + self.v_head_dim).split(
+                    [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
+            key = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))),
+                            dim=-1)
+            pad_query = torch.nn.functional.pad(query, [
+                0, self.padding_head_dim - self.qk_rope_head_dim -
+                self.qk_nope_head_dim
+            ],
+                                                value=0)
+            pad_key = torch.nn.functional.pad(key, [
+                0, self.padding_head_dim - self.qk_rope_head_dim -
+                self.qk_nope_head_dim
+            ],
+                                              value=0)
+            pad_value = torch.nn.functional.pad(
+                value, [0, self.padding_head_dim - self.v_head_dim], value=0)
+            torch_npu._npu_flash_attention(
+                query=pad_query,
+                key=pad_key,
+                value=pad_value,
+                mask=attn_metadata.attn_mask,
+                seq_len=attn_metadata.prefill.context_lens,
+                scale_value=self.scale,
+                num_heads=self.num_heads,
+                num_kv_heads=self.num_heads,
+                out=attn_output)
+            attn_output = attn_output.view(
+                -1, self.num_heads,
+                self.padding_head_dim)[:, :, :self.v_head_dim]
+        else:
+            raise RuntimeError(
+                "Unexpected path reached, AscendMLAImpl should only have PrefillOnly and ChunkedPrefill scenario in forward prefill, please file a bug to vllm-ascend !"
+            )
+        attn_output = attn_output.reshape(
             [num_tokens, self.num_heads * self.v_head_dim])
         return self.o_proj(attn_output)[0]
```

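The new PrefillOnly branch follows a pad, compute, slice pattern: zero-pad q/k/v up to padding_head_dim, run the fused attention, then slice the output back to v_head_dim before o_proj. A framework-agnostic sketch of just the shape handling (the attention call itself is stubbed; only the padding and slicing are the point):

```python
import torch
import torch.nn.functional as F

num_tokens, num_heads = 4, 8
qk_head_dim, v_head_dim, padded_dim = 192, 128, 256

query = torch.randn(num_tokens, num_heads, qk_head_dim)
key = torch.randn(num_tokens, num_heads, qk_head_dim)
value = torch.randn(num_tokens, num_heads, v_head_dim)

# Zero-pad the last dim up to the 128-aligned size the kernel expects.
pad_q = F.pad(query, [0, padded_dim - qk_head_dim])
pad_k = F.pad(key, [0, padded_dim - qk_head_dim])
pad_v = F.pad(value, [0, padded_dim - v_head_dim])

# Stand-in for torch_npu._npu_flash_attention: any attention that writes one
# padded_dim-sized vector per (token, head) works for the shape check here.
out = torch.empty(num_tokens, num_heads, padded_dim)
out.copy_(pad_v)  # placeholder "attention output"

# Slice the padded output back down to the real value head dim.
out = out[:, :, :v_head_dim].reshape(num_tokens, num_heads * v_head_dim)
assert out.shape == (num_tokens, num_heads * v_head_dim)
```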
```diff
@@ -457,7 +499,7 @@ def _forward_decode(

         q = torch.cat([q_nope, q_pe], dim=-1)
         num_tokens = q.size(0)
-        attn_output = torch.randn(
+        attn_output = torch.empty(
             [num_tokens, self.num_heads, self.kv_lora_rank],
             dtype=q.dtype,
             device=q.device)
```
```diff
@@ -522,8 +564,10 @@ def forward(
             decode_ql_nope, decode_q_pe = \
                 self._q_proj_and_k_up_proj(decode_hs_or_q_c)
             decode_q_pe[...], decode_k_pe[...] = self.rotary_emb(
-                attn_metadata.decode.input_positions, decode_q_pe.contiguous(),
-                decode_k_pe)
+                attn_metadata.decode.input_positions,
+                decode_q_pe.contiguous(),
+                decode_k_pe,
+                max_seq_len=attn_metadata.decode.max_seq_lens)

         if has_prefill:
             assert attn_metadata.prefill is not None
```
```diff
@@ -533,7 +577,9 @@ def forward(

             prefill_q_pe[...], prefill_k_pe[...] = self.rotary_emb(
                 attn_metadata.prefill.input_positions,
-                prefill_q_pe.contiguous(), prefill_k_pe)
+                prefill_q_pe.contiguous(),
+                prefill_k_pe,
+                max_seq_len=attn_metadata.prefill.max_seq_lens)

             if kv_cache.numel() > 0:
                 key = torch.cat([
```
