Commit 04e6169

[Feat] Implement primal full graph with limited scenario (#1503)
### What this PR does / why we need it?

This pull request introduces full-graph capture, replacing the previous piecewise-graph approach. Key improvements include:

* **Reduced dispatch latency:** By capturing the entire model execution graph at once, we minimize overhead compared to multiple smaller captures.
* **Stabilized multi-GPU performance:** Eliminates throughput fluctuations during the `MODEL_EXECUTE` phase across multiple cards.
* **Stream resource savings:** Consolidating graph captures frees up streams, allowing more graphs to be captured concurrently.

**Known issues:**

1. Capturing larger or more numerous graphs increases GPU memory usage, which can lead to OOM errors or inference hangs.
2. The new paged-attention implementation relies on the FIA operator, which in certain workloads is slower than the previous approach, resulting in a regression in end-to-end throughput.

There may be other undiscovered corner cases. This PR is the first in a planned series; we will continue to iterate on and address any remaining issues in subsequent submissions.

### Does this PR introduce _any_ user-facing change?

```python
compilation_config={
    "full_cuda_graph": True,
},
```

### How was this patch tested?

---------

Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
1 parent e1d282d commit 04e6169
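For convenience, a minimal usage sketch of the new option, mirroring the updated test in `tests/singlecard/test_aclgraph.py`; the model name and capture sizes below are illustrative choices, not recommendations from this PR:

```python
from vllm import LLM, SamplingParams

# Sketch: enable full-graph capture on vllm-ascend via compilation_config.
# "cudagraph_capture_sizes" is optional; a shorter list bounds the extra
# device memory that captured graphs consume (known issue 1 above).
llm = LLM(
    model="Qwen/Qwen2.5-0.5B-Instruct",  # illustrative model choice
    compilation_config={
        "full_cuda_graph": True,
        "cudagraph_capture_sizes": [1, 4, 16, 64, 256],
    },
)

outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.0, max_tokens=32))
print(outputs[0].outputs[0].text)
```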

File tree

10 files changed: +316 -60 lines changed

tests/singlecard/test_aclgraph.py

Lines changed: 11 additions & 1 deletion
```diff
@@ -36,9 +36,11 @@
                     reason="aclgraph only support on v1")
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("max_tokens", [32])
+@pytest.mark.parametrize("full_graph", [False])
 def test_models(
     model: str,
     max_tokens: int,
+    full_graph: bool,
     monkeypatch: pytest.MonkeyPatch,
 ) -> None:
     with monkeypatch.context() as m:
@@ -54,7 +56,15 @@ def test_models(
                                          temperature=0.0)
         # TODO: change to use vllmrunner when the registry of custom op is solved
         # while running pytest
-        vllm_model = LLM(model)
+        if full_graph:
+            vllm_model = LLM(model,
+                             compilation_config={
+                                 "full_cuda_graph": True,
+                                 "cudagraph_capture_sizes":
+                                 [1, 4, 16, 64, 256]
+                             })
+        else:
+            vllm_model = LLM(model)
         vllm_aclgraph_outputs = vllm_model.generate(prompts, sampling_params)
         del vllm_model
         torch.npu.empty_cache()
```

vllm_ascend/ascend_forward_context.py

Lines changed: 4 additions & 0 deletions
```diff
@@ -55,6 +55,10 @@ def set_ascend_forward_context(
 
     forward_context.in_profile_run = in_profile_run
 
+    # NOTE: This cannot be set using set_forward_context
+    # due to multiple warmups before actual capturing
+    forward_context.capturing = False
+
     dp_world_size = get_dp_group().world_size
     if dp_world_size > 1 and forward_context.dp_metadata is not None:
         forward_context.max_tokens_across_dp = forward_context.dp_metadata.max_tokens_across_dp_cpu.item(
```
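The default set above is read later by the attention backend (`if not forward_context.capturing:` in `attention_v1.py`). The capture loop that flips it is not part of this diff; below is a hypothetical sketch of the intended pattern, with `capture_decode_graph` and `run_model_once` as made-up placeholders for the runner's capture routine:

```python
from vllm.forward_context import get_forward_context


def capture_decode_graph(run_model_once):
    """Hypothetical illustration only; the real capture loop lives in the runner."""
    forward_context = get_forward_context()

    # Warm-up replays keep capturing=False (the default set above), so the
    # decode attention path runs the FIA kernel eagerly.
    run_model_once()

    # Only the pass that actually records the graph flips the flag, routing
    # attention into the graph_task_group_begin/end capture branch.
    forward_context.capturing = True
    try:
        run_model_once()
    finally:
        forward_context.capturing = False
```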

vllm_ascend/attention/attention_v1.py

Lines changed: 137 additions & 34 deletions
```diff
@@ -24,12 +24,16 @@
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionLayer, AttentionType)
 from vllm.attention.backends.utils import CommonAttentionState
+from vllm.config import get_current_vllm_config
 from vllm.forward_context import ForwardContext, get_forward_context
 from vllm.utils import direct_register_custom_op
 from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.worker.gpu_input_batch import InputBatch
 
+from vllm_ascend.attention.utils import \
+    AscendCommonAttentionMetadata as CommonAttentionMetadata
 from vllm_ascend.ops.attention import vanilla_chunked_prefill
+from vllm_ascend.utils import get_graph_params
 
 
 class AscendAttentionBackend(AttentionBackend):
@@ -114,6 +118,7 @@ class AscendMetadata:
     query_start_loc: torch.Tensor
     query_lens: torch.Tensor
     seq_lens: torch.Tensor
+    seq_lens_list: list
     # Maximum query length in the batch. None for decoding.
     max_query_len: Optional[int] = None
     # (num_tokens,). The indices of the token slots that input tokens will be
@@ -149,37 +154,69 @@ def build(self,
               num_reqs,
               num_actual_tokens,
               max_query_len,
-              common_prefix_len,
-              enable_dbo_across_dp: bool = False):
+              common_attn_metadata: CommonAttentionMetadata,
+              enable_dbo_across_dp: bool = False,
+              *args,
+              **kwargs):
 
         block_table = self.runner.input_batch.block_table[0].get_device_tensor(
         )
         block_table[:num_reqs, :self.runner.max_num_blocks_per_req] = (
             block_table[:num_reqs])
 
-        query_lens = self.runner.query_lens
-        seq_lens = self.runner.seq_lens_cpu[:num_reqs]
-        slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to(
-            self.runner.device, non_blocking=True)
+        query_start_loc = common_attn_metadata.query_start_loc
+        seq_lens = common_attn_metadata.seq_lens
+        # TODO: Refactor these two param to common metadata in runners,
+        # preparing for the hybrid KV groups feature
+        query_lens = common_attn_metadata.query_lens if common_attn_metadata.query_lens is not None else self.runner.query_lens
+        seq_lens_list = common_attn_metadata.seq_lens_list if common_attn_metadata.seq_lens_list is not None else self.runner.seq_lens_list
+
+        slot_mapping = self.runner.slot_mapping[:num_actual_tokens]
         attn_mask = self.runner.attn_mask
         attn_state = self.runner.attn_state
-        query_start_loc_cpu = self.runner.query_start_loc_cpu[:num_reqs + 1]
-        query_start_loc = query_start_loc_cpu.to(self.runner.device,
-                                                 non_blocking=True)
 
         attn_metadata = AscendMetadata(
             num_actual_tokens=num_actual_tokens,
             block_tables=block_table,
             query_start_loc=query_start_loc,
             query_lens=query_lens,
             seq_lens=seq_lens,
+            seq_lens_list=seq_lens_list,
             max_query_len=max_query_len,
             slot_mapping=slot_mapping,
             attn_mask=attn_mask,
             attn_state=attn_state,
             enable_dbo_across_dp=enable_dbo_across_dp)
         return attn_metadata
 
+    def build_dummy_metadata(self, num_actual_tokens, num_reqs,
+                             num_scheduled_tokens, attn_state):
+        if attn_state == AscendAttentionState.DecodeOnly:
+            # NOTE: We only need to pay attention to seq_lens_list and block_table here
+            common_attn_metadata = CommonAttentionMetadata(seq_lens_list=[2] *
+                                                           num_reqs)
+
+            block_table = self.runner.input_batch.block_table[0].block_table
+            block_table[:num_reqs, 0] = torch.arange(1,
+                                                     num_reqs + 1,
+                                                     device=block_table.device,
+                                                     dtype=block_table.dtype)
+
+            attn_metadata = self.build(
+                num_reqs=num_reqs,
+                num_actual_tokens=num_actual_tokens,
+                max_query_len=num_scheduled_tokens.max(),
+                common_prefix_len=0,
+                common_attn_metadata=common_attn_metadata,
+            )
+        else:
+            raise NotImplementedError(
+                "Currently we only support building dummy metadata for DecodeOnly state"
+            )
+
+        attn_metadata.attn_state = attn_state
+        return attn_metadata
+
 
 class AscendAttentionBackendImpl(AttentionImpl):
 
@@ -217,6 +254,10 @@ def __init__(
         self.key_cache = None
         self.value_cache = None
 
+        vllm_config = get_current_vllm_config()
+        self.full_graph = vllm_config.compilation_config.full_cuda_graph
+        self.block_size = vllm_config.cache_config.block_size
+
     def forward(
         self,
         layer: AttentionLayer,
@@ -228,21 +269,7 @@ def forward(
         output: Optional[torch.Tensor] = None,
         trace_flag: bool = True,
     ) -> torch.Tensor:
-        """Forward pass with Ascend attention.
-        Args:
-            query: shape = [batch_size, seq_len, num_heads * head_size]
-            key: shape = [batch_size, seq_len, num_kv_heads * head_size]
-            value: shape = [batch_size, seq_len, num_kv_heads * head_size]
-            kv_cache: shape = [2, num_blocks, block_size,
-                               num_kv_heads, head_size]
-                key_cache = [num_blocks, block_size,
-                             num_kv_heads, head_size]
-                value_cache = [num_blocks, block_size,
-                               num_kv_heads, head_size]
-            attn_metadata: Metadata for attention.
-        Returns:
-            shape = [batch_size * seq_len, num_heads, head_size]
-        """
+        """Forward pass with Ascend attention."""
         num_tokens = query.shape[0]
         if output is None:
             output = torch.empty(num_tokens,
@@ -322,16 +349,92 @@ def forward(
                 scale_value=self.scale,
                 out=output)
         elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
-            torch_npu._npu_paged_attention(
-                query=query,
-                key_cache=self.key_cache,
-                value_cache=self.value_cache,
-                num_kv_heads=self.num_kv_heads,
-                num_heads=self.num_heads,
-                scale_value=self.scale,
-                block_table=attn_metadata.block_tables,
-                context_lens=attn_metadata.seq_lens,
-                out=output)
+            if self.full_graph:
+                graph_params = get_graph_params()
+                q = query.view(num_tokens, -1, self.hidden_size)
+                k = self.key_cache.view(  # type: ignore
+                    -1, self.block_size,
+                    self.num_kv_heads * self.head_size)
+                v = self.value_cache.view(  # type: ignore
+                    -1, self.block_size,
+                    self.num_kv_heads * self.head_size)
+                actual_seq_lens = attn_metadata.seq_lens_list
+                attn_args = {
+                    "query": q,
+                    "key": k,
+                    "value": v,
+                    "actual_seq_lengths_kv": actual_seq_lens,
+                    "block_table": attn_metadata.block_tables,
+                    "num_heads": self.num_heads,
+                    "scale": self.scale,
+                    "input_layout": "BSH",
+                    "num_key_value_heads": self.num_kv_heads,
+                    "block_size": self.block_size,
+                }
+
+                # Prepare tensors for attention output
+                # TODO: Refactor this to step-level instead of layer-level
+                attn_output = torch.empty(num_tokens,
+                                          1,
+                                          self.hidden_size,
+                                          dtype=output.dtype,
+                                          device=output.device)
+                softmax_lse = torch.empty(num_tokens,
+                                          dtype=output.dtype,
+                                          device=output.device)
+
+                # Get workspace from cache or calculate it if not present.
+                workspace = graph_params.workspaces.get(num_tokens)
+                if workspace is None:
+                    workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(
+                        **attn_args)
+                    graph_params.workspaces[num_tokens] = workspace
+
+                forward_context = get_forward_context()
+                if not forward_context.capturing:
+                    # Execute attention kernel directly in non-capturing mode
+                    torch.ops.npu.npu_fused_infer_attention_score.out(
+                        workspace=workspace,
+                        out=[attn_output, softmax_lse],
+                        **attn_args)
+                else:
+                    # Handle graph capturing mode
+                    stream = torch_npu.npu.current_stream()
+
+                    event = torch.npu.ExternalEvent()
+                    event.wait(stream)
+                    event.reset(stream)
+                    graph_params.events[num_tokens].append(event)
+
+                    graph_params.attn_params[num_tokens].append(
+                        (q, k, v, actual_seq_lens,
+                         attn_metadata.block_tables, self.num_heads,
+                         self.scale, self.num_kv_heads, attn_output,
+                         softmax_lse))
+
+                    torch.npu.graph_task_group_begin(stream)
+                    torch.ops.npu.npu_fused_infer_attention_score.out(
+                        workspace=workspace,
+                        out=[attn_output, softmax_lse],
+                        **attn_args)
+                    handle = torch.npu.graph_task_group_end(stream)
+                    graph_params.handles[num_tokens].append(handle)
+
+                # Reshape output to match the expected format
+                output.copy_(
+                    attn_output.view(num_tokens, self.num_heads,
+                                     self.head_size))
+            else:
+                torch_npu._npu_paged_attention(
+                    query=query,
+                    key_cache=self.key_cache,
+                    value_cache=self.value_cache,
+                    num_kv_heads=self.num_kv_heads,
+                    num_heads=self.num_heads,
+                    scale_value=self.scale,
+                    block_table=attn_metadata.block_tables,
+                    context_lens=attn_metadata.seq_lens,
+                    out=output)
         # Normal V1 situation.
         else:
             # use chunked prefill for head size 192 scenario, like deepseek
```
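A small, self-contained shape sketch (example sizes only, not taken from this PR) of the `input_layout="BSH"` views the new full-graph decode path builds before calling the fused-infer-attention (FIA) kernel, and of the final copy back into the layout the rest of the layer expects:

```python
import torch

# Example sizes only; in the code above these come from the model/cache config.
num_tokens, num_heads, num_kv_heads, head_size = 8, 32, 8, 128
num_blocks, block_size = 16, 128
hidden_size = num_heads * head_size

# Query arrives flattened as (num_tokens, num_heads * head_size) and is viewed
# as BSH with a sequence length of 1 per decode token.
query = torch.randn(num_tokens, hidden_size)
q = query.view(num_tokens, -1, hidden_size)  # (B=num_tokens, S=1, H=hidden_size)

# The paged KV cache is viewed per block, with the KV heads folded into H.
key_cache = torch.randn(num_blocks, block_size, num_kv_heads, head_size)
k = key_cache.view(-1, block_size, num_kv_heads * head_size)

# FIA produces a BSH output, which is copied back into the
# (num_tokens, num_heads, head_size) buffer used by the rest of the layer.
attn_output = torch.randn(num_tokens, 1, hidden_size)
output = torch.empty(num_tokens, num_heads, head_size)
output.copy_(attn_output.view(num_tokens, num_heads, head_size))
```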

vllm_ascend/attention/mla_v1.py

Lines changed: 2 additions & 14 deletions
```diff
@@ -16,6 +16,8 @@
 from vllm_ascend import envs
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
+from vllm_ascend.attention.utils import \
+    AscendCommonAttentionMetadata as CommonAttentionMetadata
 from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
 from vllm_ascend.multistream.context import get_multistream_comm_context
 from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
@@ -28,20 +30,6 @@
 from vllm.v1.worker.gpu_input_batch import InputBatch
 
 
-@dataclass
-class CommonAttentionMetadata:
-    """
-    Attention metadata attributes that can be shared by layers in different KV
-    cache groups and thus having different block table.
-    """
-
-    query_start_loc: torch.Tensor
-    """(batch_size + 1,), the start location of each request in query Tensor"""
-    seq_lens: torch.Tensor
-    """(batch_size,), the length of each request including both computed tokens
-    and newly scheduled tokens"""
-
-
 class AscendMLABackend(AttentionBackend):
 
     accept_output_buffer: bool = True
```

vllm_ascend/attention/utils.py

Lines changed: 23 additions & 0 deletions
```diff
@@ -0,0 +1,23 @@
+from dataclasses import dataclass
+from typing import Optional
+
+import torch
+
+
+@dataclass
+class AscendCommonAttentionMetadata:
+    """
+    Attention metadata attributes that can be shared by layers in different KV
+    cache groups and thus having different block table.
+    """
+
+    query_start_loc: torch.Tensor = None
+    """(batch_size + 1,), the start location of each request in query Tensor"""
+    seq_lens: Optional[torch.Tensor] = None
+    """(batch_size,), the length of each request including both computed tokens
+    and newly scheduled tokens"""
+    query_lens: Optional[torch.Tensor] = None
+    """(batch_size,), the length of each request including only the newly
+    scheduled tokens"""
+    seq_lens_list: Optional[list] = None
+    """(num_input_tokens,), note that this is specifically for FIA kernel"""
```
