[V1][CUDA] Full cudagraph support for FlashInfer #21367
base: main
Changes from 2 commits
vllm/v1/attention/backends/utils.py

@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import abc
+import enum
 import functools
 from abc import abstractmethod
 from dataclasses import dataclass
@@ -62,10 +63,23 @@ class CommonAttentionMetadata:

 M = TypeVar("M")

+class AttentionCGSupport(enum.Enum):
Review thread on AttentionCGSupport:

Comment: Nit: How do you feel about changing the name to …?

Reply: That would generally be a better name if we had other operators that may also not support cudagraphs; in that case, we could move it out of attention. I think we can keep the name for now, and this part also serves as a stepping stone for the upcoming #20059.
""" Constants for the cudagraph support of the attention backend | ||
Here we do not consider the cascade attention, as currently | ||
it is never cudagraph supported.""" | ||
|
||
NEVER = 0 | ||
"""NO cudagraph support""" | ||
PURE_DECODE_ONLY = 1 | ||
"""Cudagraph supported for pure decode, need to use piecewise | ||
cudagraph or no cudagraph for mixed prefill-decode batches""" | ||
Comment: In the mixed prefill-decode case: when will piecewise cudagraph be used, and when will no cudagraph be used?

Reply: I am sorry, this comment is taken directly from #20059; I didn't realize there is no piecewise cudagraph in this PR when enabling …

Reply: It's fixed now.
+    ALWAYS = 2
+    """Cudagraph always supported"""
+
+
 class AttentionMetadataBuilder(abc.ABC, Generic[M]):
     # Does this backend/builder support CUDA Graphs for attention.
-    full_cudagraph_supported: ClassVar[bool] = False
+    attn_cudagraph_support: ClassVar[AttentionCGSupport] = \
+        AttentionCGSupport.NEVER

     @abstractmethod
     def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
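To make the intended usage concrete, here is a minimal, self-contained sketch of how a backend builder declares its support level and how a runner can branch on it. `DummyDecodeOnlyBuilder` and `check_full_cudagraph` are illustrative names, not part of this PR; only the `AttentionCGSupport` values mirror the diff above.

```python
import enum
from typing import ClassVar


class AttentionCGSupport(enum.Enum):
    NEVER = 0              # no full-cudagraph support
    PURE_DECODE_ONLY = 1   # full cudagraph only for pure-decode batches
    ALWAYS = 2             # full cudagraph for any batch composition


class DummyDecodeOnlyBuilder:
    """Stand-in for a decode-only backend builder such as FlashInfer's."""
    attn_cudagraph_support: ClassVar[AttentionCGSupport] = \
        AttentionCGSupport.PURE_DECODE_ONLY


def check_full_cudagraph(builder, full_cuda_graph: bool) -> None:
    # Mirrors the runner-side validation in this PR: backends that can never
    # run under a full cudagraph are rejected up front.
    if full_cuda_graph and builder.attn_cudagraph_support == AttentionCGSupport.NEVER:
        raise ValueError("Full CUDAGraph not supported for this backend; "
                         "turn off full-cudagraph compilation")


check_full_cudagraph(DummyDecodeOnlyBuilder, full_cuda_graph=True)  # passes
```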
vllm/v1/worker/gpu_model_runner.py

@@ -47,7 +47,7 @@
                             is_pin_memory_available, round_up)
 from vllm.v1.attention.backends.mamba_attn import Mamba2AttentionBackend
 from vllm.v1.attention.backends.utils import (
-    AttentionMetadataBuilder, CommonAttentionMetadata,
+    AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata,
     make_local_attention_virtual_batches)
 from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
 from vllm.v1.kv_cache_interface import (AttentionSpec,
@@ -2369,6 +2369,15 @@
         # can reuse the memory pool allocated for the large shapes.
         with graph_capture(device=self.device):
+            full_cg = self.full_cuda_graph
+            # for full cg on pure decode only, do not capture size larger
+            # than max_num_seqs
+            if full_cg and self.attn_metadata_builders[0].attn_cudagraph_support \
+                    == AttentionCGSupport.PURE_DECODE_ONLY:
+                max_num_seqs = self.scheduler_config.max_num_seqs
+                self.cudagraph_batch_sizes = [
+                    size for size in self.cudagraph_batch_sizes
+                    if size <= max_num_seqs]
Comment: This change correctly filters … This discrepancy can lead to a … To fix this, you should re-initialize the padding map after filtering …

Suggested change: …

Reply: Good catch for …

Comment: I find that the following hangs: … at this point: … It just hangs here - could this be related?

Comment: If I remove the …

Comment: I find overriding gpu model runner's …

Comment: Hmm, unfortunately it doesn't happen on main + FlashAttention. The hang is 100% reproducible using the code from this PR.

Reply: I will try my best to figure it out.

Reply: I have tested …

Reply: @tdoublep Take the new fix! It should be fine now. Could you please also test if it works for you?

Comment: Yes, seems to work now. Thanks!
fhl2000 marked this conversation as resolved.

             # Only rank 0 should print progress bar during capture
             compilation_cases = reversed(self.cudagraph_batch_sizes)
             if is_global_first_rank():
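To make the capture-size filtering and the padding-map concern from the thread above concrete, here is a standalone sketch. The size list, the `max_num_seqs` value, and the `build_padding_table` / `pad_to_graph_size` helpers are hypothetical stand-ins, not the runner's actual attributes or the reviewer's suggested change.

```python
import bisect

# Hypothetical capture sizes and scheduler limit (illustration only).
cudagraph_batch_sizes = [1, 2, 4, 8, 16, 32, 64, 128]
max_num_seqs = 32

# The PR's change for PURE_DECODE_ONLY backends: a captured decode graph can
# never see more than max_num_seqs requests, so larger sizes are dropped.
cudagraph_batch_sizes = [s for s in cudagraph_batch_sizes if s <= max_num_seqs]

# The reviewer's point: any lookup built from the *unfiltered* list (e.g.
# "pad batch size n up to the next captured size") must be rebuilt after
# filtering, or it may return a size that was never captured.
def build_padding_table(sizes):
    sizes = sorted(sizes)
    return lambda n: sizes[bisect.bisect_left(sizes, n)] if n <= sizes[-1] else None

pad_to_graph_size = build_padding_table(cudagraph_batch_sizes)
assert pad_to_graph_size(20) == 32    # padded up to a captured size
assert pad_to_graph_size(40) is None  # no captured graph; fall back to eager
```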
@@ -2438,7 +2447,8 @@
             )

             if (self.full_cuda_graph
-                    and not attn_metadata_builder_i.full_cudagraph_supported):
+                    and attn_metadata_builder_i.attn_cudagraph_support == \
+                        AttentionCGSupport.NEVER):
                 raise ValueError(
                     f"Full CUDAGraph not supported for "
                     f"{attn_backend_i.__name__}. Turn off CompilationConfig."
Comment: Nit: You shouldn't need to set this. You can add it to your decode_only assert on the previous line.

Reply: I think this is a common practice now (also see this part for FlashMLA), as the attn_metadata passed from the dummy run currently has max_query_len=num_tokens.
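For context on the max_query_len point: the metadata coming from the dummy run has max_query_len equal to the number of tokens, so a decode-only builder normalizes it rather than asserting on it directly. Below is a rough sketch of that pattern with hypothetical names (`FakeCommonMetadata`, `build_for_capture`), not the PR's actual code.

```python
from dataclasses import dataclass


@dataclass
class FakeCommonMetadata:
    """Illustrative stand-in for CommonAttentionMetadata (only fields used below)."""
    num_reqs: int
    num_actual_tokens: int
    max_query_len: int


def build_for_capture(meta: FakeCommonMetadata) -> FakeCommonMetadata:
    # A pure-decode graph sees exactly one token per request. The dummy run
    # arrives with max_query_len == num_actual_tokens, so check that the batch
    # is decode-shaped, then override max_query_len instead of asserting
    # max_query_len == 1 up front.
    assert meta.num_actual_tokens == meta.num_reqs, \
        "capture batch for a decode-only backend must be pure decode"
    meta.max_query_len = 1
    return meta


# Dummy capture batch: 8 requests, one token each, but max_query_len arrives as 8.
meta = build_for_capture(
    FakeCommonMetadata(num_reqs=8, num_actual_tokens=8, max_query_len=8))
assert meta.max_query_len == 1
```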