Commit 0f9e735

[BugFix] Fix full-cuda-graph illegal memory access in FA3 (#20057)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
1 parent ba7ba35 commit 0f9e735

File tree

1 file changed: +7 −18 lines changed

vllm/v1/attention/backends/flash_attn.py

Lines changed: 7 additions & 18 deletions
@@ -158,12 +158,13 @@ def __init__(self, runner: "GPUModelRunner", kv_cache_spec: AttentionSpec,
 
         self.aot_schedule = (get_flash_attn_version() == 3)
         self.use_full_cuda_graph = compilation_config.full_cuda_graph
-        if self.use_full_cuda_graph and not self.aot_schedule:
-            raise ValueError("Full CUDA graph mode requires AOT scheduling, "
-                             "which requires FlashAttention 3.")
-        self.scheduler_metadata = torch.zeros(self.runner.max_num_reqs + 1,
-                                              dtype=torch.int32,
-                                              device=self.runner.device)
+        if self.use_full_cuda_graph:
+            # NOTE(lucas): AOT scheduling not supported in full cuda graph mode
+            # yet. This is because the scheduler and kernel need to always use
+            # the same num_splits (which acts as an upper bound with the
+            # dynamic split scheduler) which is currently heuristically decided
+            # by the kernel launching code.
+            self.aot_schedule = False
 
         # Sliding window size to be used with the AOT scheduler will be
         # populated on first build() call.
@@ -299,18 +300,6 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
             max_seq_len=max_seq_len,
             causal=True)
 
-        if self.use_full_cuda_graph:
-            assert scheduler_metadata is not None
-            n = scheduler_metadata.shape[0]
-            self.scheduler_metadata[:n].copy_(scheduler_metadata,
-                                              non_blocking=True)
-            # NOTE(woosuk): We should zero out the rest of the scheduler
-            # metadata to guarantee the correctness. Otherwise, some thread
-            # blocks may use the invalid scheduler metadata and overwrite the
-            # output buffer.
-            self.scheduler_metadata[n:] = 0
-            scheduler_metadata = self.scheduler_metadata[:n]
-
         attn_metadata = FlashAttentionMetadata(
             num_actual_tokens=num_actual_tokens,
             max_query_len=max_query_len,
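
For context, a minimal standalone sketch (not part of the commit) of the decision the builder now makes: AOT scheduling stays on only when FlashAttention 3 is available and full CUDA graph capture is off. The helper name `resolve_aot_schedule` is hypothetical; `flash_attn_version` and `full_cuda_graph` mirror values taken from `get_flash_attn_version()` and `compilation_config.full_cuda_graph` in the diff.

# Hedged sketch of the post-fix behavior; not vLLM code.
def resolve_aot_schedule(flash_attn_version: int, full_cuda_graph: bool) -> bool:
    """Return True only when the FA3 AOT scheduler can be used safely."""
    aot_schedule = (flash_attn_version == 3)
    if full_cuda_graph:
        # Per the commit's NOTE: the scheduler and kernel must agree on
        # num_splits, which is currently chosen heuristically by the
        # kernel-launch code, so AOT scheduling is disabled under full
        # CUDA graphs instead of risking an illegal memory access.
        aot_schedule = False
    return aot_schedule

if __name__ == "__main__":
    # FA3 with full CUDA graphs now falls back to the non-AOT path.
    assert resolve_aot_schedule(3, full_cuda_graph=True) is False
    assert resolve_aot_schedule(3, full_cuda_graph=False) is True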
