[V1][CUDA] Full cudagraph support for FlashInfer #21367
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import abc
+import enum
 import functools
 from abc import abstractmethod
 from dataclasses import dataclass
@@ -63,9 +64,24 @@ class CommonAttentionMetadata:
 M = TypeVar("M")


+class AttentionCGSupport(enum.Enum):
Review comment: Nit: How do you feel about changing the name to
Reply: That would generally be a better name if we had other operators that might also not support cudagraph; in that case, we could move it out of attention. I think we can keep the name for now, and this part also serves as a stepping stone for the upcoming #20059.
+    """Constants for the cudagraph support of the attention backend.
+    Here we do not consider cascade attention, as it is currently
+    never supported with cudagraph."""
+
+    NEVER = 0
+    """No cudagraph support."""
+    PURE_DECODE_ONLY = 1
+    """Cudagraph supported for pure decode; mixed prefill-decode
+    batches must run without cudagraph."""
+    ALWAYS = 2
+    """Cudagraph always supported."""
+
+
 class AttentionMetadataBuilder(abc.ABC, Generic[M]):
     # Does this backend/builder support CUDA Graphs for attention.
-    full_cudagraph_supported: ClassVar[bool] = False
+    attn_cudagraph_support: ClassVar[AttentionCGSupport] = \
+        AttentionCGSupport.NEVER

     @abstractmethod
     def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
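For context, a concrete backend opts in by overriding this class attribute. The snippet below is a minimal sketch of that pattern; the class name `FlashInferMetadataBuilder` and the chosen `PURE_DECODE_ONLY` level are assumptions based on this PR's goal, not lines taken from the diff.

```python
# Hypothetical sketch: how an attention metadata builder declares its
# cudagraph support level with the new AttentionCGSupport enum.
from typing import ClassVar

from vllm.v1.attention.backends.utils import (AttentionCGSupport,
                                              AttentionMetadataBuilder)


class FlashInferMetadataBuilder(AttentionMetadataBuilder):
    # Assumed for illustration: FlashInfer replays full cudagraphs only for
    # pure-decode batches; mixed prefill-decode batches run without cudagraph.
    attn_cudagraph_support: ClassVar[AttentionCGSupport] = \
        AttentionCGSupport.PURE_DECODE_ONLY
```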
@@ -47,7 +47,7 @@
     is_pin_memory_available, round_up)
 from vllm.v1.attention.backends.mamba_attn import Mamba2AttentionBackend
 from vllm.v1.attention.backends.utils import (
-    AttentionMetadataBuilder, CommonAttentionMetadata,
+    AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata,
     make_local_attention_virtual_batches)
 from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
 from vllm.v1.kv_cache_interface import (AttentionSpec,
@@ -2527,12 +2527,20 @@ def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None:
             self.device,
         )

-        if (self.full_cuda_graph
-                and not attn_metadata_builder_i.full_cudagraph_supported):
-            raise ValueError(
-                f"Full CUDAGraph not supported for "
-                f"{attn_backend_i.__name__}. Turn off CompilationConfig."
-                f"full_cuda_graph or use a different attention backend.")
+        if self.full_cuda_graph:
+            if attn_metadata_builder_i.attn_cudagraph_support == \
+                    AttentionCGSupport.NEVER:
+                raise ValueError(
+                    f"Full CUDAGraph not supported for "
+                    f"{attn_backend_i.__name__}. Turn off "
+                    f"CompilationConfig.full_cuda_graph or use a "
+                    f"different attention backend.")
+            if attn_metadata_builder_i.attn_cudagraph_support == \
Review comment: Nit: Could you add a comment here explaining what's going on and why it's necessary? Something like
+                    AttentionCGSupport.PURE_DECODE_ONLY:
+                self.cudagraph_batch_sizes = [
+                    size for size in self.cudagraph_batch_sizes
+                    if size <= self.scheduler_config.max_num_seqs
+                ]

         self.attn_backends.append(attn_backend_i)
         self.attn_metadata_builders.append(attn_metadata_builder_i)
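To make the `PURE_DECODE_ONLY` branch concrete: a pure-decode batch contributes at most one token per request, so any capture size above `max_num_seqs` could only be reached by a mixed batch, which this support level cannot replay. A standalone sketch with made-up values:

```python
# Standalone sketch of the capture-size filtering above; the concrete
# sizes and the max_num_seqs value are illustrative, not vLLM defaults.
cudagraph_batch_sizes = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512]
max_num_seqs = 128  # stands in for scheduler_config.max_num_seqs

# Keep only sizes that a pure-decode batch (one token per request) can hit.
cudagraph_batch_sizes = [
    size for size in cudagraph_batch_sizes if size <= max_num_seqs
]
print(cudagraph_batch_sizes)  # [1, 2, 4, 8, 16, 32, 64, 128]
```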
@@ -300,11 +300,16 @@ def compile_or_warm_up_model(self) -> None:
         if get_pp_group().is_last_rank:
             max_num_reqs = min(self.scheduler_config.max_num_seqs,
                                self.scheduler_config.max_num_batched_tokens)
+            # Activate building attn_metadata for this dummy run to avoid
+            # potential illegal memory access during full cudagraph replay.
+            attn_cudagraph = self.compilation_config.full_cuda_graph and \
Review comment (@SageMoore): I'm not sure I understand why you need this. AIUI, this code is specifically warming up shapes that are not in the cudagraph capture list? Is this required because you modified the list in the gpu_model_runner? I see there's some discussion about a hang when you don't pass attention metadata into the dummy_run?
Reply: Hey @SageMoore, thank you for the questions! I think they are not related, but let me explain in more detail. This line of code runs after all cudagraph shapes in the modified list have already been captured in gpu_model_runner, so this dummy_run with num_tokens=max_num_reqs is <= the maximum captured size of that list. Also recall that a dummy_run under attn_cudagraph_support=PURE_DECODE_ONLY only runs pure-decode batches, so it hits cudagraph replay only if the size is in the list; otherwise it runs without cudagraph. However, when it does hit the replay, FlashInfer may get trapped in an infinite loop if the contents of the persistent buffers are incorrect.
+                not self.model_config.enforce_eager

             # We skip EPLB here since we don't want to record dummy metrics
             hidden_states, last_hidden_states = \
                 self.model_runner._dummy_run(
                     num_tokens=max_num_reqs,
+                    capture_attn_cudagraph=attn_cudagraph,
                     skip_eplb=True,
                 )
             if self.model_runner.is_pooling_model:
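The warm-up decision itself is a simple conjunction; the sketch below restates it outside the worker with stub config objects (the dataclasses are illustrative stand-ins, not vLLM's real config classes):

```python
# Sketch of the decision above: only build attention metadata for the
# warm-up dummy run when full cudagraph is enabled and eager mode is
# not forced. The dataclasses are stand-ins for vLLM's config objects.
from dataclasses import dataclass


@dataclass
class CompilationConfigStub:
    full_cuda_graph: bool = False


@dataclass
class ModelConfigStub:
    enforce_eager: bool = False


def want_attn_cudagraph_warmup(compilation: CompilationConfigStub,
                               model: ModelConfigStub) -> bool:
    # Mirrors: attn_cudagraph = full_cuda_graph and not enforce_eager
    return compilation.full_cuda_graph and not model.enforce_eager


assert want_attn_cudagraph_warmup(CompilationConfigStub(True),
                                  ModelConfigStub(False))
assert not want_attn_cudagraph_warmup(CompilationConfigStub(True),
                                      ModelConfigStub(True))
```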
Review comment: Nit: You shouldn't need to set this. You can add it to your decode_only assert on the previous line.
Reply: I think this is a common practice now (see also this part for FlashMLA), since the attn_metadata passed from the dummy run currently has max_query_len=num_tokens.