[V1][CUDA] Full cudagraph support for FlashInfer #21367
base: main
Changes from 3 commits
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import abc
+import enum
 import functools
 from abc import abstractmethod
 from dataclasses import dataclass
@@ -63,9 +64,24 @@ class CommonAttentionMetadata:
 M = TypeVar("M")


+class AttentionCGSupport(enum.Enum):
Review comment: Nit: How do you feel about changing the name to

Reply: That is generally a better name if we have other operators that may also not support cudagraph; in that case, we can move it out of attention. I think we can retain the name for now, and this part also serves as a stepping stone for the upcoming #20059.
+    """ Constants for the cudagraph support of the attention backend
+    Here we do not consider the cascade attention, as currently
+    it is never cudagraph supported."""
+
+    NEVER = 0
+    """NO cudagraph support"""
+    PURE_DECODE_ONLY = 1
+    """Cudagraph supported for pure decode, need to use piecewise
+    cudagraph or no cudagraph for mixed prefill-decode batches"""
Review comment: In the mixed prefill-decode case, when will piecewise cudagraph be used and when will no cudagraph be used?

Reply: Sorry, this comment was taken directly from #20059; I didn't realize that no piecewise cudagraph exists in this PR when enabling

Reply: It's fixed now.
+    ALWAYS = 2
+    """Cudagraph always supported"""
+
+
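The question in the thread above (what happens to mixed prefill-decode batches when the backend is only PURE_DECODE_ONLY) can be summarized with a small sketch. This is a hypothetical illustration, not vLLM's actual runner code: the function name `select_cudagraph_mode`, the `piecewise_compiled` flag, and the string return values are assumptions; the enum is a local mirror of the one added in this diff.

```python
import enum


class AttentionCGSupport(enum.Enum):
    """Local mirror of the enum added in this diff (sketch only)."""
    NEVER = 0
    PURE_DECODE_ONLY = 1
    ALWAYS = 2


def select_cudagraph_mode(support: AttentionCGSupport,
                          pure_decode_batch: bool,
                          piecewise_compiled: bool) -> str:
    """Pick how to run one batch; return values are illustrative labels."""
    if support is AttentionCGSupport.ALWAYS:
        # Attention can always be captured, so a full graph is fine.
        return "full"
    if support is AttentionCGSupport.PURE_DECODE_ONLY and pure_decode_batch:
        # Decode-only batches can still take the full-cudagraph path.
        return "full"
    # Mixed prefill-decode (or NEVER): fall back to piecewise graphs if the
    # model was compiled with them, otherwise run eagerly without cudagraph.
    return "piecewise" if piecewise_compiled else "none"


# Example: a FlashInfer-like PURE_DECODE_ONLY backend.
print(select_cudagraph_mode(AttentionCGSupport.PURE_DECODE_ONLY, True, False))   # full
print(select_cudagraph_mode(AttentionCGSupport.PURE_DECODE_ONLY, False, False))  # none
```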
 class AttentionMetadataBuilder(abc.ABC, Generic[M]):
     # Does this backend/builder support CUDA Graphs for attention.
-    full_cudagraph_supported: ClassVar[bool] = False
+    attn_cudagraph_support: ClassVar[AttentionCGSupport] = \
+        AttentionCGSupport.NEVER

     @abstractmethod
     def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
Review comment: Nit: You shouldn't need to set this. You can add it to your decode_only assert on the previous line.

Reply: I think this is common practice now (also see this part for FlashMLA), as the attn_metadata passed from the dummy run currently has max_query_len=num_tokens.
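The exchange above is easier to follow with a sketch of the capture-time path it refers to. The following is a minimal, hypothetical illustration, assuming a pared-down CommonAttentionMetadata and a simplified build_for_cudagraph_capture hook; the field set and signature are stand-ins, not the exact vLLM code. It shows why a decode-only builder checks the dummy batch and pins max_query_len itself rather than relying on the value handed over by the dummy run.

```python
from dataclasses import dataclass


@dataclass
class CommonAttentionMetadata:
    # Only the fields this sketch needs; the real dataclass has more.
    num_reqs: int
    num_actual_tokens: int
    max_query_len: int


def build_for_cudagraph_capture(common: CommonAttentionMetadata) -> CommonAttentionMetadata:
    # Capture-time hook of a decode-only backend builder (simplified).
    # The dummy run currently passes max_query_len == num_tokens, so the
    # builder asserts the batch is pure decode and overrides max_query_len
    # instead of asserting on the incoming value.
    assert common.num_reqs == common.num_actual_tokens, (
        "full cudagraph capture is decode-only for this backend")
    common.max_query_len = 1  # one token per request in pure decode
    # ... proceed with the backend's normal metadata build ...
    return common


meta = build_for_cudagraph_capture(
    CommonAttentionMetadata(num_reqs=8, num_actual_tokens=8, max_query_len=8))
print(meta.max_query_len)  # 1
```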