Commit df18f1d

[0.9.1][2/N][Feat] Restore paged attention kernel in Full Graph for performance (#1677)
### What this PR does / why we need it?
Rectified the performance regression wherein the FIA kernel underperformed the PA kernel, by enabling dynamic updates of PA parameters during graph replay.

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
1 parent 9ca9c6f commit df18f1d

File tree

3 files changed: 67 additions, 99 deletions


tests/singlecard/test_aclgraph.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -35,8 +35,8 @@
 @pytest.mark.skipif(os.getenv("VLLM_USE_V1") == "0",
                     reason="aclgraph only support on v1")
 @pytest.mark.parametrize("model", MODELS)
-@pytest.mark.parametrize("max_tokens", [32])
-@pytest.mark.parametrize("full_graph", [False])
+@pytest.mark.parametrize("max_tokens", [12])
+@pytest.mark.parametrize("full_graph", [True, False])
 def test_models(
     model: str,
     max_tokens: int,
```
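Note: stacked `pytest.mark.parametrize` decorators expand into the cartesian product of their values, so the updated test now runs each model with `max_tokens=12` under both `full_graph=True` and `full_graph=False`. A minimal standalone sketch of that expansion (the `MODELS` value below is a placeholder, not the list from the real test module):

```python
import pytest

# Placeholder stand-in for the MODELS list defined in the real test module.
MODELS = ["some-small-model"]


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("max_tokens", [12])
@pytest.mark.parametrize("full_graph", [True, False])
def test_parametrize_expansion(model: str, max_tokens: int, full_graph: bool):
    # Stacked decorators yield 1 model x 1 max_tokens x 2 full_graph = 2 cases.
    assert max_tokens == 12
    assert isinstance(full_graph, bool)
```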

vllm_ascend/attention/attention_v1.py

Lines changed: 43 additions & 80 deletions
```diff
@@ -118,7 +118,7 @@ class AscendMetadata:
     query_start_loc: torch.Tensor
     query_lens: torch.Tensor
     seq_lens: torch.Tensor
-    seq_lens_list: list
+    seq_lens_list: Optional[list[int]]
     # Maximum query length in the batch. None for decoding.
     max_query_len: Optional[int] = None
     # (num_tokens,). The indices of the token slots that input tokens will be
@@ -168,8 +168,9 @@ def build(self,
         seq_lens = common_attn_metadata.seq_lens
         # TODO: Refactor these two param to common metadata in runners,
         # preparing for the hybrid KV groups feature
-        query_lens = common_attn_metadata.query_lens if common_attn_metadata.query_lens is not None else self.runner.query_lens
-        seq_lens_list = common_attn_metadata.seq_lens_list if common_attn_metadata.seq_lens_list is not None else self.runner.seq_lens_list
+        query_lens = common_attn_metadata.query_lens or self.runner.query_lens
+        # Since FIA for GQA is not active now, we temporarily silence it
+        seq_lens_list = common_attn_metadata.seq_lens_list
 
         slot_mapping = self.runner.slot_mapping[:num_actual_tokens]
         attn_mask = self.runner.attn_mask
@@ -193,8 +194,8 @@ def build_dummy_metadata(self, num_actual_tokens, num_reqs,
                              num_scheduled_tokens, attn_state):
         if attn_state == AscendAttentionState.DecodeOnly:
             # NOTE: We only need to pay attention to seq_lens_list and block_table here
-            common_attn_metadata = CommonAttentionMetadata(seq_lens_list=[2] *
-                                                           num_reqs)
+            common_attn_metadata = CommonAttentionMetadata(
+                seq_lens=torch.empty_like(self.runner.seq_lens_cpu).fill_(2))
 
             block_table = self.runner.input_batch.block_table[0].block_table
             block_table[:num_reqs, 0] = torch.arange(1,
@@ -349,82 +350,42 @@ def forward(
                 scale_value=self.scale,
                 out=output)
         elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
-            if self.full_graph:
-                graph_params = get_graph_params()
-                q = query.view(num_tokens, -1, self.hidden_size)
-                k = self.key_cache.view(  # type: ignore
-                    -1, self.block_size,
-                    self.num_kv_heads * self.head_size)
-                v = self.value_cache.view(  # type: ignore
-                    -1, self.block_size,
-                    self.num_kv_heads * self.head_size)
-                actual_seq_lens = attn_metadata.seq_lens_list
-                attn_args = {
-                    "query": q,
-                    "key": k,
-                    "value": v,
-                    "actual_seq_lengths_kv": actual_seq_lens,
-                    "block_table": attn_metadata.block_tables,
-                    "num_heads": self.num_heads,
-                    "scale": self.scale,
-                    "input_layout": "BSH",
-                    "num_key_value_heads": self.num_kv_heads,
-                    "block_size": self.block_size,
-                }
-
-                # Prepare tensors for attention output
-                # TODO: Refactor this to step-level instead of layer-level
-                attn_output = torch.empty(num_tokens,
-                                          1,
-                                          self.hidden_size,
-                                          dtype=output.dtype,
-                                          device=output.device)
-                softmax_lse = torch.empty(num_tokens,
-                                          dtype=output.dtype,
-                                          device=output.device)
-
-                # Get workspace from cache or calculate it if not present.
-                workspace = graph_params.workspaces.get(num_tokens)
-                if workspace is None:
-                    workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(
-                        **attn_args)
-                    graph_params.workspaces[num_tokens] = workspace
-
-                forward_context = get_forward_context()
-                if not forward_context.capturing:
-                    # Execute attention kernel directly in non-capturing mode
-                    torch.ops.npu.npu_fused_infer_attention_score.out(
-                        workspace=workspace,
-                        out=[attn_output, softmax_lse],
-                        **attn_args)
-                else:
-                    # Handle graph capturing mode
-                    stream = torch_npu.npu.current_stream()
-
-                    event = torch.npu.ExternalEvent()
-                    event.wait(stream)
-                    event.reset(stream)
-                    graph_params.events[num_tokens].append(event)
-
-                    graph_params.attn_params[num_tokens].append(
-                        (q, k, v, actual_seq_lens,
-                         attn_metadata.block_tables, self.num_heads,
-                         self.scale, self.num_kv_heads, attn_output,
-                         softmax_lse))
-
-                    torch.npu.graph_task_group_begin(stream)
-                    torch.ops.npu.npu_fused_infer_attention_score.out(
-                        workspace=workspace,
-                        out=[attn_output, softmax_lse],
-                        **attn_args)
-                    handle = torch.npu.graph_task_group_end(stream)
-                    graph_params.handles[num_tokens].append(handle)
-
-                # Reshape output to match the expected format
-                output.copy_(
-                    attn_output.view(num_tokens, self.num_heads,
-                                     self.head_size))
+            graph_params = get_graph_params()
+
+            forward_context = get_forward_context()
+            if not forward_context.capturing:
+                torch_npu._npu_paged_attention(
+                    query=query,
+                    key_cache=self.key_cache,
+                    value_cache=self.value_cache,
+                    num_kv_heads=self.num_kv_heads,
+                    num_heads=self.num_heads,
+                    scale_value=self.scale,
+                    block_table=attn_metadata.block_tables,
+                    context_lens=attn_metadata.seq_lens,
+                    out=output)
             else:
+                # Handle graph capturing mode
+                stream = torch_npu.npu.current_stream()
+
+                event = torch.npu.ExternalEvent()
+                event.wait(stream)
+                event.reset(stream)
+                graph_params.events[num_tokens].append(event)
+
+                graph_params.attn_params[num_tokens].append((
+                    query,
+                    self.key_cache,
+                    self.value_cache,
+                    self.num_kv_heads,
+                    self.num_heads,
+                    self.scale,
+                    attn_metadata.block_tables,
+                    attn_metadata.seq_lens,
+                    output,
+                ))
+
+                torch.npu.graph_task_group_begin(stream)
                 torch_npu._npu_paged_attention(
                     query=query,
                     key_cache=self.key_cache,
@@ -435,6 +396,8 @@ def forward(
                     block_table=attn_metadata.block_tables,
                     context_lens=attn_metadata.seq_lens,
                     out=output)
+                handle = torch.npu.graph_task_group_end(stream)
+                graph_params.handles[num_tokens].append(handle)
         # Normal V1 situation.
         else:
             # use chunked prefill for head size 192 scenario, like deepseek
```
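Note: the decode branch above keys all of its graph bookkeeping by `num_tokens` (the runtime shape being captured): one `ExternalEvent`, one tuple of captured kernel arguments, and one task-group handle per attention layer. A rough sketch of the container shape this implies is below; the real object comes from `get_graph_params()` in vllm-ascend, so the field names and types here are an assumption for illustration only.

```python
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Any


# Assumed shape of the bookkeeping returned by get_graph_params();
# illustrative only, not the actual vllm-ascend definition.
@dataclass
class GraphParamsSketch:
    # Keyed by runtime shape (num_tokens of the captured decode batch);
    # each list holds one entry per attention layer.
    events: dict[int, list[Any]] = field(
        default_factory=lambda: defaultdict(list))   # torch.npu.ExternalEvent per layer
    attn_params: dict[int, list[tuple]] = field(
        default_factory=lambda: defaultdict(list))   # captured kernel arguments per layer
    handles: dict[int, list[Any]] = field(
        default_factory=lambda: defaultdict(list))   # graph task-group handles per layer
    workspaces: dict[int, Any] = field(
        default_factory=dict)                        # only used by the removed FIA path
```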

vllm_ascend/compilation/piecewise_backend.py

Lines changed: 22 additions & 17 deletions
```diff
@@ -23,6 +23,7 @@
 
 import torch
 import torch.fx as fx
+import torch_npu
 import vllm.envs as envs
 from vllm.compilation.backends import VllmBackend
 from vllm.compilation.counter import compilation_counter
@@ -126,29 +127,33 @@ def check_for_ending_compilation(self):
 
     def update_attn_params(self, graph_params, forward_context, runtime_shape):
         for layer_idx in range(len(graph_params.handles[runtime_shape])):
-            query, key, value, actual_seq_lens, block_table, num_heads, scale, num_kv_heads, output, softmax_lse = graph_params.attn_params[
-                runtime_shape][layer_idx]
+            (
+                query,
+                key_cache,
+                value_cache,
+                num_kv_heads,
+                num_heads,
+                scale,
+                block_table,
+                seq_lens,
+                output,
+            ) = graph_params.attn_params[runtime_shape][layer_idx]
             block_table = forward_context.attn_metadata.block_tables
-            actual_seq_lens = forward_context.attn_metadata.seq_lens_list
+            seq_lens = forward_context.attn_metadata.seq_lens
 
             with torch.npu.stream(self.update_stream):
                 torch.npu.graph_task_update_begin(
                     self.update_stream,
                     graph_params.handles[runtime_shape][layer_idx])
-                torch.ops.npu.npu_fused_infer_attention_score.out(
-                    query,
-                    key,
-                    value,
-                    workspace=graph_params.workspaces[runtime_shape],
-                    actual_seq_lengths_kv=actual_seq_lens,
-                    block_table=block_table,
-                    num_heads=num_heads,
-                    scale=scale,
-                    input_layout="BSH",
-                    num_key_value_heads=num_kv_heads,
-                    block_size=128,
-                    out=[output, softmax_lse],
-                )
+                torch_npu._npu_paged_attention(query=query,
+                                               key_cache=key_cache,
+                                               value_cache=value_cache,
+                                               num_kv_heads=num_kv_heads,
+                                               num_heads=num_heads,
+                                               scale_value=scale,
+                                               block_table=block_table,
+                                               context_lens=seq_lens,
+                                               out=output)
                 torch.npu.graph_task_update_end(self.update_stream)
 
             graph_params.events[runtime_shape][layer_idx].record(
```
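Note: this method is what the PR description means by enabling dynamic updates of PA parameters during graph replay: before a captured graph is replayed, each layer's paged-attention launch is re-issued under `graph_task_update_begin`/`graph_task_update_end` with the current block table and context lengths. Below is a condensed, commented restatement of the same loop for readability; the final `record(...)` argument is cut off in the hunk above, so `self.update_stream` is assumed here.

```python
import torch
import torch_npu


def update_attn_params(self, graph_params, forward_context, runtime_shape):
    for layer_idx in range(len(graph_params.handles[runtime_shape])):
        # Tensors and scalars captured for this layer when the graph was recorded.
        (query, key_cache, value_cache, num_kv_heads, num_heads, scale,
         block_table, seq_lens, output) = graph_params.attn_params[runtime_shape][layer_idx]

        # The block table and context lengths change on every decode step, so
        # rebind them from the current forward context before replaying.
        block_table = forward_context.attn_metadata.block_tables
        seq_lens = forward_context.attn_metadata.seq_lens

        with torch.npu.stream(self.update_stream):
            # Re-issue the paged-attention launch inside a task update so the
            # kernel captured under this handle picks up the new parameters.
            torch.npu.graph_task_update_begin(
                self.update_stream,
                graph_params.handles[runtime_shape][layer_idx])
            torch_npu._npu_paged_attention(query=query,
                                           key_cache=key_cache,
                                           value_cache=value_cache,
                                           num_kv_heads=num_kv_heads,
                                           num_heads=num_heads,
                                           scale_value=scale,
                                           block_table=block_table,
                                           context_lens=seq_lens,
                                           out=output)
            torch.npu.graph_task_update_end(self.update_stream)

        # Release the ExternalEvent the captured graph is waiting on
        # (argument assumed; the hunk above truncates after `.record(`).
        graph_params.events[runtime_shape][layer_idx].record(self.update_stream)
```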
