From 92b1733e2c8ae1629894f0f38fa72cac39dadaad Mon Sep 17 00:00:00 2001 From: fhl <2410591650@qq.com> Date: Wed, 25 Jun 2025 13:36:04 +0800 Subject: [PATCH 01/33] FA2 and FlashInfer Full cuda graph support Signed-off-by: fhl <2410591650@qq.com> --- vllm/compilation/backends.py | 16 ++- vllm/compilation/base_piecewise_backend.py | 43 ++++++ vllm/compilation/cuda_piecewise_backend.py | 152 ++++++++++++++++++++- vllm/config.py | 23 +++- vllm/forward_context.py | 12 +- vllm/platforms/cuda.py | 4 + vllm/platforms/interface.py | 8 ++ vllm/v1/attention/backends/flash_attn.py | 11 +- vllm/v1/attention/backends/flashinfer.py | 149 +++++++++++++++++--- vllm/v1/worker/gpu_model_runner.py | 106 +++++++++++--- 10 files changed, 460 insertions(+), 64 deletions(-) diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index a2bb053cec4..fb9a6c94320 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -563,10 +563,6 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable: self._called = True - if not self.compilation_config.use_cudagraph or \ - not self.compilation_config.cudagraph_copy_inputs: - return self.split_gm - # if we need to copy input buffers for cudagraph from torch._guards import detect_fake_mode fake_mode = detect_fake_mode() @@ -585,6 +581,18 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable: any(is_symbolic(d) for d in x.size()) ] + if self.compilation_config.full_cuda_graph: + assert self.compilation_config.use_cudagraph, \ + "full_cuda_graph mode requires use_cudagraph to be True" + fullgraph_wrapper = resolve_obj_by_qualname( + current_platform.get_fullgraph_wrapper_cls()) + self.split_gm = fullgraph_wrapper(self.split_gm, self.vllm_config, + self.graph_pool, self.sym_tensor_indices) + + if not self.compilation_config.use_cudagraph or \ + not self.compilation_config.cudagraph_copy_inputs: + return self.split_gm + # compiler managed cudagraph input buffers # we assume the first run with symbolic shapes # has the maximum size among all the tensors diff --git a/vllm/compilation/base_piecewise_backend.py b/vllm/compilation/base_piecewise_backend.py index 4d7aeeb4d03..a728f6f3724 100644 --- a/vllm/compilation/base_piecewise_backend.py +++ b/vllm/compilation/base_piecewise_backend.py @@ -70,3 +70,46 @@ def __call__(self, *args) -> Any: or a replayed static graph. """ raise NotImplementedError + + +class AbstractFullgraphWrapper(Protocol): + """ + FullgraphWrapper interface that allows platforms to wrap the piecewise graph + to be viewed or captured as a full graph. + """ + + def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig, + graph_pool: Any, sym_shape_indices: list[int], **kwargs): + """ + Initializes the FullgraphWrapper class with compilation and + execution-related configurations. + + Args: + graph (fx.GraphModule): The graph represented in fx. + vllm_config (VllmConfig): Global configuration for vLLM. + graph_pool (Any): + Graph memory pool handle, e.g., + `torch.cuda.graph_pool_handle()`. + sym_shape_indices (list[int]): + Indices of symbolic shape. + + Keyword Args: + kwargs: Additional keyword arguments reserved for future + extensions or custom platforms. + + """ + raise NotImplementedError + + def __call__(self, *args) -> Any: + """ + Executes the wrapped graph for given input args. + + Args: + *args: Variable length input arguments to be passed into the + graph. The symbolic shape is expected to be in position + `sym_shape_indices[0]`. 
+ + Returns: + Any: Output of the executed wrapped graph. + """ + raise NotImplementedError diff --git a/vllm/compilation/cuda_piecewise_backend.py b/vllm/compilation/cuda_piecewise_backend.py index 8c49ea6cc10..f9d01c604f2 100644 --- a/vllm/compilation/cuda_piecewise_backend.py +++ b/vllm/compilation/cuda_piecewise_backend.py @@ -37,6 +37,8 @@ class ConcreteSizeEntry: # during capture, and check if they are the same during replay input_addresses: Optional[list[int]] = None + usage_type: Optional[str] = None + class CUDAPiecewiseBackend: @@ -96,6 +98,7 @@ def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig, runtime_shape=shape, need_to_compile=shape in self.compile_sizes, use_cudagraph=shape in self.cudagraph_capture_sizes, + usage_type="piecewise(general)", # for logging only ) def check_for_ending_compilation(self): @@ -139,9 +142,12 @@ def __call__(self, *args) -> Any: self.check_for_ending_compilation() # Skip CUDA graphs if this entry doesn't use them OR - # if we're supposed to skip them globally - skip_cuda_graphs = get_forward_context().skip_cuda_graphs - if not entry.use_cudagraph or skip_cuda_graphs: + # if we're supposed to treat the piecewise graphs as a whole, + # which implies forward_context.skip_attention_cuda_graphs is False. + # In the latter case, we rely on a wrapper class to capture + # the full cudagraph outside the fx graph. + skip_attention_cuda_graphs = get_forward_context().skip_attention_cuda_graphs + if not entry.use_cudagraph or not skip_attention_cuda_graphs: return entry.runnable(*args) if entry.cudagraph is None: @@ -149,9 +155,10 @@ def __call__(self, *args) -> Any: entry.num_finished_warmup += 1 if self.is_first_graph: logger.debug( - "Warming up %s/%s for shape %s", + "Warming up %s/%s of %s usage for shape %s", entry.num_finished_warmup, self.compilation_config.cudagraph_num_of_warmups, + entry.usage_type, runtime_shape) return entry.runnable(*args) @@ -159,7 +166,8 @@ def __call__(self, *args) -> Any: # Since we capture cudagraph for many different shapes and # capturing is fast, we don't need to log it for every shape. # We only log it in the debug mode. 
- logger.debug("Capturing a cudagraph for shape %s", + logger.debug("Capturing a cudagraph of %s usage for shape %s", + entry.usage_type, runtime_shape) input_addresses = [ @@ -216,3 +224,137 @@ def __call__(self, *args) -> Any: entry.cudagraph.replay() return entry.output + + +class FullCudagraphWrapper: + def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig, + graph_pool: Any, sym_shape_indices: list[int], + ): + self.graph = graph + self.vllm_config = vllm_config + self.compilation_config = vllm_config.compilation_config + self.graph_pool = graph_pool + self.sym_shape_indices = sym_shape_indices + + self.separate_attention_routine = vllm_config.compilation_config.separate_attention_routine + + self.is_debugging_mode = envs.VLLM_LOGGING_LEVEL == "DEBUG" + + self.first_run_finished = False + + self.cudagraph_capture_sizes: set[int] = set( + self.compilation_config.cudagraph_capture_sizes + ) if self.compilation_config.use_cudagraph else set() + + self.concrete_size_entries: dict[int, ConcreteSizeEntry] = {} + self.concrete_size_entries_decode: dict[int, ConcreteSizeEntry] = {} + + + for shape in self.cudagraph_capture_sizes: + self.concrete_size_entries[shape] = ConcreteSizeEntry( + runtime_shape=shape, + need_to_compile=False, + use_cudagraph=True, + usage_type="general", + ) + if self.separate_attention_routine: + self.concrete_size_entries_decode[shape] = ConcreteSizeEntry( + runtime_shape=shape, + need_to_compile=False, + use_cudagraph=True, + usage_type="decode", + ) + + def __call__(self, *args) -> Any: + if not self.first_run_finished: + self.first_run_finished = True + return self.graph(*args) + list_args = list(args) + runtime_shape = list_args[self.sym_shape_indices[0]].shape[0] + forward_context = get_forward_context() + + if forward_context.skip_attention_cuda_graphs: + # turn back to piecewise cudagraphs backend, which is responsible + # for capturing and running the piecewise cudagraphs. + return self.graph(*args) + + # if not skip, the fx graph and its sub-graphs will only be supposed to + # eagerly run the compiled graphs, which should be cudagraph capturable + # as a whole. + + concrete_size_entries = self.concrete_size_entries # default as general usage + if self.separate_attention_routine and forward_context.is_pure_decoding: + concrete_size_entries = self.concrete_size_entries_decode + + if not runtime_shape in concrete_size_entries: + # we don't need to do anything for this shape. + return self.graph(*args) + + entry = concrete_size_entries[runtime_shape] + + if entry.runnable is None: + entry.runnable = self.graph + + if not entry.use_cudagraph: + return entry.runnable(*args) + + if entry.cudagraph is None: + if entry.num_finished_warmup < self.compilation_config.cudagraph_num_of_warmups: # noqa + entry.num_finished_warmup += 1 + logger.debug( + "Warming up %s/%s of %s usage for shape %s", + entry.num_finished_warmup, + self.compilation_config.cudagraph_num_of_warmups, + entry.usage_type, + runtime_shape) + return entry.runnable(*args) + + + # Since we capture cudagraph for many different shapes and + # capturing is fast, we don't need to log it for every shape. + # We only log it in the debug mode. 
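For orientation, the capture step below follows the standard PyTorch CUDAGraph capture/replay pattern with a shared memory pool. A minimal standalone sketch of that pattern (the callable, shapes, and warmup policy here are illustrative; the real wrapper additionally tracks per-shape entries and weak-references its outputs):

import torch

def capture(fn, static_inputs, pool=None):
    # Warm up eagerly first so one-time allocations/initializations are not captured.
    fn(*static_inputs)
    torch.cuda.synchronize()
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph, pool=pool):
        # Outputs produced during capture live in the graph's memory pool.
        static_output = fn(*static_inputs)
    return graph, static_output

pool = torch.cuda.graph_pool_handle()          # shared pool, like self.graph_pool here
x = torch.randn(8, 16, device="cuda")          # persistent input buffer
graph, out = capture(lambda t: t @ t.T, (x,), pool)

x.copy_(torch.randn(8, 16, device="cuda"))     # refill the same buffer ...
graph.replay()                                  # ... and replay the recorded kernels
# `out` now holds results for the new input without re-running Python code.
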
+ + logger.debug("Capturing a cudagraph of %s usage for shape %s", + entry.usage_type, + runtime_shape) + + input_addresses = [ + x.data_ptr() for x in args if isinstance(x, torch.Tensor) + ] + entry.input_addresses = input_addresses + cudagraph = torch.cuda.CUDAGraph() + + with ExitStack() as stack: + # mind-exploding: carefully manage the reference and memory. + with torch.cuda.graph(cudagraph, pool=self.graph_pool): + # `output` is managed by pytorch's cudagraph pool + output = entry.runnable(*args) + # by converting it to weak ref, + # the original `output` will immediately be released + # to save memory. + output = weak_ref_tensors(output) + + # here we always use weak ref for the output + # to save memory + entry.output = weak_ref_tensors(output) + entry.cudagraph = cudagraph + + compilation_counter.num_cudagraph_captured += 1 + + # important: we need to return the output, rather than + # the weak ref of the output, so that pytorch can correctly + # manage the memory during cuda graph capture + return output + + if self.is_debugging_mode: + # check if the input addresses are the same + new_input_addresses = [ + x.data_ptr() for x in args if isinstance(x, torch.Tensor) + ] + assert new_input_addresses == entry.input_addresses, ( + "Input addresses for cudagraphs are different during replay." + f" Expected {entry.input_addresses}, got {new_input_addresses}" + ) + + entry.cudagraph.replay() + return entry.output \ No newline at end of file diff --git a/vllm/config.py b/vllm/config.py index 4333dcd3b8a..bcf3052d9d8 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -3974,6 +3974,14 @@ class CompilationConfig: splitting certain operations such as attention into subgraphs. Thus this flag cannot be used together with splitting_ops. This may provide performance benefits for smaller models.""" + separate_attention_routine: bool = False + """ + Enable a distinct attention calls routine under an attention backend for full + cuda graph capturing. This is because some attention backends like FlashMLA, + FlashInfer, FA2, etc. implement different branches for mix prefill-decode and + pure decode cases. This flag enables us to potentially capture the cudagraph + separately for each branch. + """ pass_config: PassConfig = field(default_factory=PassConfig) """Custom inductor passes, see PassConfig for more details""" @@ -4172,13 +4180,16 @@ def init_with_cudagraph_sizes(self, def set_splitting_ops_for_v1(self): # NOTE: this function needs to be called - if self.splitting_ops and self.full_cuda_graph: - raise ValueError("full_cuda_graph cannot be used together with " - "splitting_ops, as Full CUDA graph will override " - f"the splitting_ops: {self.splitting_ops}") - + # NOTE: When full_cuda_graph is True, instead of setting an empty + # list and capture the full cudagraph inside the flattened fx graph, + # we keep the piecewise fx graph structure but capture the full + # cudagraph outside the fx graph. This reduces some cpu overhead when + # the runtime batch_size is not cudagraph captured. This is only + # supported for separate_attention_routine. 
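For illustration, the flag combination this check enforces would be configured roughly as follows (a hedged sketch; `CompilationConfig` is the class this diff modifies, and how the config reaches the engine is omitted):

from vllm.config import CompilationConfig

compilation_config = CompilationConfig(
    use_cudagraph=True,               # piecewise cudagraphs stay available
    full_cuda_graph=True,             # capture full graphs outside the fx graph
    separate_attention_routine=True,  # capture pure-decode and mixed batches separately
)
# splitting_ops keeps the attention ops, so the piecewise fx structure is
# preserved while the FullCudagraphWrapper captures the graph as a whole.
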
+ if self.separate_attention_routine: + assert self.full_cuda_graph, "separate_attention_routine requires full_cuda_graph to be True" if not self.splitting_ops: - self.splitting_ops = [] if self.full_cuda_graph else [ + self.splitting_ops = [ "vllm.unified_attention", "vllm.unified_attention_with_output", ] diff --git a/vllm/forward_context.py b/vllm/forward_context.py index dd55b19feea..8b7a187a877 100644 --- a/vllm/forward_context.py +++ b/vllm/forward_context.py @@ -94,7 +94,11 @@ class ForwardContext: virtual_engine: int # set dynamically for each forward pass # set dynamically for each forward pass dp_metadata: Optional[DPMetadata] = None - skip_cuda_graphs: bool = False + # determine whether to use a full cudagraph for attention or piecewise + # cudagraphs that skip the attention part. By default true, we use piecewise + # cudagraphs. + skip_attention_cuda_graphs: bool = True, + is_pure_decoding: bool = False _forward_context: Optional[ForwardContext] = None @@ -115,7 +119,8 @@ def set_forward_context( virtual_engine: int = 0, num_tokens: Optional[int] = None, num_tokens_across_dp: Optional[torch.Tensor] = None, - skip_cuda_graphs: bool = False, + skip_attention_cuda_graphs: bool = True, + is_pure_decoding: bool = False, ): """A context manager that stores the current forward context, can be attention metadata, etc. @@ -140,7 +145,8 @@ def set_forward_context( virtual_engine=virtual_engine, attn_metadata=attn_metadata, dp_metadata=dp_metadata, - skip_cuda_graphs=skip_cuda_graphs, + skip_attention_cuda_graphs=skip_attention_cuda_graphs, + is_pure_decoding=is_pure_decoding, ) try: diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index 879d094f657..de99aee9fea 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -370,6 +370,10 @@ def use_custom_allreduce(cls) -> bool: @classmethod def get_piecewise_backend_cls(cls) -> str: return "vllm.compilation.cuda_piecewise_backend.CUDAPiecewiseBackend" # noqa + + @classmethod + def get_fullgraph_wrapper_cls(cls) -> str: + return "vllm.compilation.cuda_piecewise_backend.FullCudagraphWrapper" # noqa @classmethod def stateless_init_device_torch_dist_pg( diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index f962fafabf5..2c18616c8bc 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -531,6 +531,14 @@ def get_piecewise_backend_cls(cls) -> str: Get piecewise backend class for piecewise graph. """ return "vllm.compilation.base_piecewise_backend.AbstractPiecewiseBackend" # noqa + + @classmethod + def get_fullgraph_wrapper_cls(cls) -> str: + """ + Get fullgraph wrapper class for fullgraph static graph. 
+ """ + return "vllm.compilation.base_piecewise_backend.AbstractFullgraphWrapper" # noqa + @classmethod def stateless_init_device_torch_dist_pg( diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index 4ad7178374b..7fe1ec2d17a 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -139,7 +139,7 @@ def _get_sliding_window_configs( class FlashAttentionMetadataBuilder( AttentionMetadataBuilder[FlashAttentionMetadata]): - full_cudagraph_supported: ClassVar[bool] = get_flash_attn_version() == 3 + full_cudagraph_supported: ClassVar[bool] = get_flash_attn_version() >= 2 def __init__(self, runner: "GPUModelRunner", kv_cache_spec: AttentionSpec, block_table: BlockTable): @@ -158,9 +158,7 @@ def __init__(self, runner: "GPUModelRunner", kv_cache_spec: AttentionSpec, self.aot_schedule = (get_flash_attn_version() == 3) self.use_full_cuda_graph = compilation_config.full_cuda_graph - if self.use_full_cuda_graph and not self.aot_schedule: - raise ValueError("Full CUDA graph mode requires AOT scheduling, " - "which requires FlashAttention 3.") + self.scheduler_metadata = torch.zeros(self.runner.max_num_reqs + 1, dtype=torch.int32, device=self.runner.device) @@ -299,8 +297,7 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens, max_seq_len=max_seq_len, causal=True) - if self.use_full_cuda_graph: - assert scheduler_metadata is not None + if scheduler_metadata is not None: n = scheduler_metadata.shape[0] self.scheduler_metadata[:n].copy_(scheduler_metadata, non_blocking=True) @@ -332,7 +329,7 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens, def can_run_in_cudagraph( self, common_attn_metadata: CommonAttentionMetadata) -> bool: - # Full CUDA Graph always supported (FA2 support checked separately) + # Full CUDA Graph always supported return True def use_cascade_attention(self, *args, **kwargs) -> bool: diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index 03a2ed7139c..01bec1fd18a 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -4,7 +4,7 @@ from __future__ import annotations from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any, Optional, ClassVar import torch from flashinfer import (BatchDecodeWithPagedKVCacheWrapper, @@ -218,22 +218,43 @@ def __post_init__(self): class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): - + full_cudagraph_supported: ClassVar[bool] = True def __init__(self, runner: GPUModelRunner, kv_cache_spec: AttentionSpec, block_table: BlockTable): self.runner = runner + self.vllm_config = runner.vllm_config self._workspace_buffer = None self._prefill_wrapper = None # Wrapper for prefill/append self._decode_wrapper = None # Wrapper for decode + self._decode_wrapper = None # Wrapper for decode (general shape) + self.enable_cuda_graph = self.vllm_config.compilation_config.full_cuda_graph + if self.enable_cuda_graph: + # For full cudagraph capture, one `decode_wrapper` for each batch + # size is needed for FlashInfer. 
+ self._decode_wrappers_cudagraph: dict[int, BatchDecodeWithPagedKVCacheWrapper] = {} + self._decode_cudagraph_max_bs = min(runner.max_num_reqs, + runner.cudagraph_batch_sizes[-1]) + self._cascade_wrapper = None # Wrapper for cascade attention # Global hyperparameters shared by all attention layers self.global_hyperparameters: Optional[PerLayerParameters] = None - self.vllm_config = runner.vllm_config self.kv_cache_spec = kv_cache_spec self.block_table = block_table + # Preparing persistent buffers + self.paged_kv_indptr = torch.zeros( + self.runner.max_num_reqs + 1, + dtype=torch.int32, + device=self.runner.device) + self.paged_kv_indices = torch.zeros( + block_table.get_device_tensor().numel(), # max num pages possible + dtype=torch.int32, device=self.runner.device) + self.paged_kv_last_page_len = torch.zeros( + self.runner.max_num_reqs, + dtype=torch.int32, device=self.runner.device) + def reorder_batch(self, input_batch: InputBatch, scheduler_output: SchedulerOutput) -> bool: # We now want to reorder the batch so that the "decode" requests are and @@ -307,19 +328,44 @@ def _get_prefill_wrapper(self): self._get_workspace_buffer(), get_kv_cache_layout()) return self._prefill_wrapper - def _get_decode_wrapper(self): - if self._decode_wrapper is None: + def _get_decode_wrapper(self, batch_size: int, use_cudagraph: bool = False): + if use_cudagraph: + decode_wrapper = self._decode_wrappers_cudagraph.get(batch_size, None) + else: + decode_wrapper = self._decode_wrapper + + if decode_wrapper is None: num_qo_heads = (self.runner.model_config.get_num_attention_heads( self.runner.parallel_config)) num_kv_heads = self.runner.model_config.get_num_kv_heads( self.runner.parallel_config) use_tensor_cores = envs.VLLM_FLASHINFER_FORCE_TENSOR_CORES or ( num_qo_heads // num_kv_heads > 4) - self._decode_wrapper = BatchDecodeWithPagedKVCacheWrapper( + + if use_cudagraph: + paged_kv_indptr = self.paged_kv_indptr[:batch_size + 1] + paged_kv_indices = self.paged_kv_indices + paged_kv_last_page_len = self.paged_kv_last_page_len[:batch_size] + else: + paged_kv_indptr = None + paged_kv_indices = None + paged_kv_last_page_len = None + decode_wrapper = BatchDecodeWithPagedKVCacheWrapper( self._get_workspace_buffer(), get_kv_cache_layout(), + use_cuda_graph=use_cudagraph, + paged_kv_indptr_buffer=paged_kv_indptr, + paged_kv_indices_buffer=paged_kv_indices, + paged_kv_last_page_len_buffer=paged_kv_last_page_len, use_tensor_cores=use_tensor_cores) - return self._decode_wrapper + + # save the decode wrapper + if use_cudagraph: + self._decode_wrappers_cudagraph[batch_size] = decode_wrapper + else: + self._decode_wrapper = decode_wrapper + + return decode_wrapper def _get_cascade_wrapper(self): if self._cascade_wrapper is None: @@ -395,11 +441,29 @@ def _plan(self, attn_metadata: FlashInferMetadata): ) if self._num_decodes > 0: - attn_metadata.decode_wrapper = self._get_decode_wrapper() + pure_decode = self._num_prefills == 0 + # possible required padding for cudagraph replay + use_cudagraph = (self.enable_cuda_graph and pure_decode and \ + self._num_decodes <= self._decode_cudagraph_max_bs) + if use_cudagraph: + num_input_tokens_decode = self.vllm_config.pad_for_cudagraph( + self._num_decodes) + else: + num_input_tokens_decode = self._num_decodes + + attn_metadata.decode_wrapper = self._get_decode_wrapper( + num_input_tokens_decode, use_cudagraph) + # TODO: Override flashinfer's plan function to avoid some + # host-to-device copy overhead. 
attn_metadata.decode_wrapper.plan( - attn_metadata.paged_kv_indptr[:self._num_decodes + 1], - attn_metadata.paged_kv_indices, - attn_metadata.paged_kv_last_page_len[:self._num_decodes], + # NOTE: Use the persistent buffer with padding length, + # instead of the same address but chunked length buffers in + # the atten_metadata. This is to be compatible with + # FlashInfer's decode_wrapper when using cudagraph. + self.paged_kv_indptr[:num_input_tokens_decode + 1], + self.paged_kv_indices if use_cudagraph else \ + attn_metadata.paged_kv_indices, + self.paged_kv_last_page_len[:num_input_tokens_decode], attn_metadata.num_qo_heads, attn_metadata.num_kv_heads, attn_metadata.head_dim, @@ -426,9 +490,17 @@ def build(self, common_prefix_len: int, device = self.runner.device qo_indptr = common_attn_metadata.query_start_loc seq_lens = common_attn_metadata.seq_lens - block_table_tensor = self.block_table.get_device_tensor()[:num_reqs] - slot_mapping = self.block_table.slot_mapping_cpu[:num_actual_tokens].to( - self.runner.device, non_blocking=True).long() + + block_table = self.block_table + block_table_tensor = block_table.get_device_tensor()[:num_reqs] + block_table.slot_mapping[:num_actual_tokens].copy_( + block_table.slot_mapping_cpu[:num_actual_tokens], + non_blocking=True) + # Fill unused with -1. Needed for reshape_and_cache in full cuda graph + # mode. + block_table.slot_mapping[num_actual_tokens:].fill_(-1) + + slot_mapping = block_table.slot_mapping[:num_actual_tokens] block_table_bounds = (seq_lens + page_size - 1) // page_size @@ -462,6 +534,12 @@ def build(self, common_prefix_len: int, device=block_table_tensor.device).unsqueeze(0) < block_table_bounds.unsqueeze(1)) paged_kv_indices = block_table_tensor[mask] + num_actual_pages = paged_kv_indices.size(0) + self.paged_kv_indices[:num_actual_pages].copy_( + paged_kv_indices, non_blocking=True) + # Fill the remaining paged_kv_last_page_len with 1. This is because + # flashinfer treats 0 as a full page instead of empty. + self.paged_kv_indices[num_actual_pages:].fill_(-1) paged_kv_indptr = torch.cat([ torch.zeros(1, @@ -469,17 +547,26 @@ def build(self, common_prefix_len: int, device=block_table_bounds.device), block_table_bounds.cumsum(dim=0, dtype=torch.int32) ]) + self.paged_kv_indptr[:1+num_reqs].copy_( + paged_kv_indptr, non_blocking=True) + # make sure self.paged_kv_indptr is not decreasing + self.paged_kv_indptr[1+num_reqs:].fill_( + paged_kv_indptr[-1]) paged_kv_last_page_len = seq_lens % page_size paged_kv_last_page_len = torch.where(paged_kv_last_page_len == 0, page_size, paged_kv_last_page_len) + self.paged_kv_last_page_len[:num_reqs].copy_( + paged_kv_last_page_len, non_blocking=True) + self.paged_kv_last_page_len[num_reqs:].fill_( + 1) attn_metadata = FlashInferMetadata( num_actual_tokens=num_actual_tokens, qo_indptr=qo_indptr, - paged_kv_indptr=paged_kv_indptr, - paged_kv_indices=paged_kv_indices, - paged_kv_last_page_len=paged_kv_last_page_len, + paged_kv_indptr=self.paged_kv_indptr[:1+num_reqs], + paged_kv_indices=self.paged_kv_indices[:num_actual_pages], + paged_kv_last_page_len=self.paged_kv_last_page_len[:num_reqs], num_qo_heads=self.runner.num_query_heads, num_kv_heads=self.kv_cache_spec.num_kv_heads, head_dim=self.kv_cache_spec.head_size, @@ -501,6 +588,34 @@ def build(self, common_prefix_len: int, self._plan(attn_metadata) return attn_metadata + + def build_for_cudagraph_capture( + self, common_attn_metadata: CommonAttentionMetadata): + """ + This method builds the metadata for full cudagraph capture. 
+ Currently, only decode is supported for full cudagraphs with FlashInfer. + """ + m = common_attn_metadata + m.query_start_loc.copy_(torch.arange(m.num_actual_tokens+1, + dtype=torch.int32, + device=self.runner.device), + non_blocking=True) + assert m.num_reqs == m.num_actual_tokens, \ + "FlashInfer only supports decode-only full CUDAGraph capture. " \ + "Make sure all cudagraph capture sizes <= max_num_seq." + + m.max_query_len = 1 # decode-only + + # Update state usually set in reorder_batch. + self._num_decodes = m.num_reqs + self._num_decode_tokens = m.num_actual_tokens + self._num_prefills = 0 + self._num_prefill_tokens = 0 + return self.build(0, m) + + def can_run_in_cudagraph( + self, common_attn_metadata: CommonAttentionMetadata) -> bool: + return common_attn_metadata.max_query_len == 1 def use_cascade_attention(self, *args, **kwargs) -> bool: if self.kv_cache_spec.dtype != self.runner.model_config.dtype: diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 40639fdf243..c701080cdfb 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -6,7 +6,7 @@ import time import weakref from contextlib import contextmanager -from typing import TYPE_CHECKING, Any, Optional, Union +from typing import TYPE_CHECKING, Any, Optional, Union, Literal import numpy as np import torch @@ -1332,9 +1332,15 @@ def execute_model( num_input_tokens, intermediate_tensors, True) # Some attention backends only support CUDA Graphs in pure decode. - # If attention doesn't support CUDA Graphs for this batch, but we - # compiled with full CUDA graphs, we have to skip them entirely. - skip_cuda_graphs = self.full_cuda_graph and not attention_cuda_graphs + # If attention doesn't support CUDA Graphs for this batch, we skip them, + # and turn back to the piecewise CUDA graphs. Or if full_cuda_graph is + # False, we always turn to the piecewise CUDA graphs. + skip_attention_cuda_graphs = not attention_cuda_graphs \ + if self.full_cuda_graph else True + # Note: When skip_attention_cuda_graphs is always False and + # compilition_config.separate_attention_routine is True, as in FA2, + # this flag helps to determine the correct routine to run for the full cudagraph. + is_pure_decoding = num_scheduled_tokens == self.input_batch.num_reqs # Run the model. # Use persistent buffers for CUDA graphs. @@ -1343,7 +1349,8 @@ def execute_model( self.vllm_config, num_tokens=num_input_tokens, num_tokens_across_dp=num_tokens_across_dp, - skip_cuda_graphs=skip_cuda_graphs, + skip_attention_cuda_graphs=skip_attention_cuda_graphs, + is_pure_decoding=is_pure_decoding, ): self.maybe_setup_kv_connector(scheduler_output) @@ -1886,7 +1893,8 @@ def rand_input_ids() -> torch.Tensor: def _dummy_run( self, num_tokens: int, - capture_attn_cudagraph: bool = False, + capture_attn_cudagraph: bool | Literal["auto"] = False, + is_pure_decoding: bool = False, ) -> tuple[torch.Tensor, torch.Tensor]: # Padding for DP @@ -1906,9 +1914,17 @@ def _dummy_run( assert len(num_scheduled_tokens_list) == num_reqs num_scheduled_tokens = np.array(num_scheduled_tokens_list, dtype=np.int32) + + # [Bugfix] This lets FA2 to correctly activate the optimized routine + # for pure decoding, i.e., Flashdecoding + an optimization for GQA/MQA. 
+ max_query_len = 1 if is_pure_decoding else num_tokens attn_metadata: Optional[dict[str, Any]] = None - if capture_attn_cudagraph: + skip_attention_cuda_graphs = True + if capture_attn_cudagraph: + # Note: At this step, `capture_attn_cudagraph` should be True or "auto", + # but we always treat it as "auto". i.e., always let the attention backends + # to determine whether to capture the attention or not. attn_metadata = {} query_start_loc = self.query_start_loc[:num_reqs + 1] @@ -1924,17 +1940,27 @@ def _dummy_run( seq_lens=seq_lens, num_reqs=num_reqs, num_actual_tokens=num_tokens, - max_query_len=num_tokens, + max_query_len=max_query_len, ) - - for kv_cache_group_id, kv_cache_group_spec in enumerate( - self.kv_cache_config.kv_cache_groups): - - attn_metadata_i = self.attn_metadata_builders[ - kv_cache_group_id].build_for_cudagraph_capture( - common_attn_metadata) - for layer_name in kv_cache_group_spec.layer_names: - attn_metadata[layer_name] = attn_metadata_i + # If all attention backends can run in a cudagraph, we use a full + # cudagraph for attention. Otherwise, turn back to piecewise cudagraphs. + attention_cuda_graphs = all( + b.can_run_in_cudagraph(common_attn_metadata) + for b in self.attn_metadata_builders) + skip_attention_cuda_graphs = not attention_cuda_graphs \ + if self.full_cuda_graph else True + + if not skip_attention_cuda_graphs: + for kv_cache_group_id, kv_cache_group_spec in enumerate( + self.kv_cache_config.kv_cache_groups): + + attn_metadata_i = self.attn_metadata_builders[ + kv_cache_group_id].build_for_cudagraph_capture( + common_attn_metadata) + for layer_name in kv_cache_group_spec.layer_names: + attn_metadata[layer_name] = attn_metadata_i + else: + attn_metadata = None # reset to None other than empty dict with self.maybe_dummy_run_with_lora(self.lora_config, num_scheduled_tokens): @@ -1967,7 +1993,9 @@ def _dummy_run( attn_metadata, self.vllm_config, num_tokens=num_tokens, - num_tokens_across_dp=num_tokens_across_dp): + num_tokens_across_dp=num_tokens_across_dp, + skip_attention_cuda_graphs=skip_attention_cuda_graphs, + is_pure_decoding=is_pure_decoding): outputs = model( input_ids=input_ids, positions=positions, @@ -2207,13 +2235,47 @@ def capture_model(self) -> None: # can reuse the memory pool allocated for the large shapes. with graph_capture(device=self.device): full_cg = self.full_cuda_graph - for num_tokens in tqdm(reversed(self.cudagraph_batch_sizes), - desc="Capturing CUDA graphs", + + # If full_cuda_graph is true, automatically determine whether or not + # to capture the attention for the mix prefill-decode (general) phase, + # based on the attention backends. + capture_attn_cudagraph_general = "auto" if full_cg else False + + # Skip capturing batch sizes of 1 in mix prefill-decode if + # separate_attention_routine is on. As bs=1 can treat as a + # pure decode. 
+ start_idx = 0 + if self.vllm_config.compilation_config.separate_attention_routine \ + and len(self.cudagraph_batch_sizes) > 0 \ + and self.cudagraph_batch_sizes[0] == 1: + start_idx = 1 + # Capture the mix prefill-decode (general usage) cudagraphs + for num_tokens in tqdm(reversed(self.cudagraph_batch_sizes[start_idx:]), + desc="Capturing CUDA graphs (mix prefill-decode)", total=len(self.cudagraph_batch_sizes)): for _ in range( self.compilation_config.cudagraph_num_of_warmups): - self._dummy_run(num_tokens, capture_attn_cudagraph=full_cg) - self._dummy_run(num_tokens, capture_attn_cudagraph=full_cg) + self._dummy_run(num_tokens, + capture_attn_cudagraph=capture_attn_cudagraph_general, + is_pure_decoding=False) + self._dummy_run(num_tokens, + capture_attn_cudagraph=capture_attn_cudagraph_general, + is_pure_decoding=False) + + if self.vllm_config.compilation_config.separate_attention_routine: + # Capture the pure decode cudagraphs. Typically a full cudagraph + + max_num_reqs = self.scheduler_config.max_num_seqs + decode_cudagraph_batch_sizes = [x for x in self.cudagraph_batch_sizes + if x <= max_num_reqs] + for num_tokens in tqdm(reversed(decode_cudagraph_batch_sizes), + desc="Capturing CUDA graphs (pure decode)", + total=len(decode_cudagraph_batch_sizes)): + for _ in range(self.compilation_config.cudagraph_num_of_warmups): + self._dummy_run(num_tokens, capture_attn_cudagraph=full_cg, + is_pure_decoding=True) + self._dummy_run(num_tokens, capture_attn_cudagraph=full_cg, + is_pure_decoding=True) end_time = time.perf_counter() end_free_gpu_memory = torch.cuda.mem_get_info()[0] From 58ce47753c72b9ae00a69cee0ddfab98b6424d6d Mon Sep 17 00:00:00 2001 From: fhl <2410591650@qq.com> Date: Wed, 25 Jun 2025 13:37:05 +0800 Subject: [PATCH 02/33] fix the arch support in CMakeLists.txt to include 8.9 Signed-off-by: fhl <2410591650@qq.com> --- CMakeLists.txt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 402131b7a1e..de3939d910e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -308,7 +308,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") # Keep building Marlin for 9.0 as there are some group sizes and shapes that # are not supported by Machete yet. 
# 9.0 for latest bf16 atomicAdd PTX - cuda_archs_loose_intersection(MARLIN_ARCHS "8.0;8.7;9.0+PTX" "${CUDA_ARCHS}") + cuda_archs_loose_intersection(MARLIN_ARCHS "8.0;8.7;8.9;9.0+PTX" "${CUDA_ARCHS}") if (MARLIN_ARCHS) # @@ -684,7 +684,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") list(APPEND VLLM_MOE_EXT_SRC "${VLLM_MOE_WNA16_SRC}") # 9.0 for latest bf16 atomicAdd PTX - cuda_archs_loose_intersection(MARLIN_MOE_ARCHS "8.0;8.7;9.0+PTX" "${CUDA_ARCHS}") + cuda_archs_loose_intersection(MARLIN_MOE_ARCHS "8.0;8.7;8.9;9.0+PTX" "${CUDA_ARCHS}") if (MARLIN_MOE_ARCHS) # From c2c5feaf4ea6e5015270f331d2691537b1e21dfe Mon Sep 17 00:00:00 2001 From: fhl <2410591650@qq.com> Date: Wed, 25 Jun 2025 15:45:45 +0800 Subject: [PATCH 03/33] Refactors Signed-off-by: fhl <2410591650@qq.com> --- vllm/compilation/cuda_piecewise_backend.py | 23 ++++++++++++---------- vllm/config.py | 15 ++++++++------ vllm/forward_context.py | 2 +- vllm/v1/attention/backends/flashinfer.py | 18 ++++++++++------- vllm/v1/worker/gpu_model_runner.py | 3 ++- 5 files changed, 36 insertions(+), 25 deletions(-) diff --git a/vllm/compilation/cuda_piecewise_backend.py b/vllm/compilation/cuda_piecewise_backend.py index f9d01c604f2..8c2dff5d351 100644 --- a/vllm/compilation/cuda_piecewise_backend.py +++ b/vllm/compilation/cuda_piecewise_backend.py @@ -236,7 +236,9 @@ def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig, self.graph_pool = graph_pool self.sym_shape_indices = sym_shape_indices - self.separate_attention_routine = vllm_config.compilation_config.separate_attention_routine + self.separate_attention_routine = ( + vllm_config.compilation_config.separate_attention_routine + ) self.is_debugging_mode = envs.VLLM_LOGGING_LEVEL == "DEBUG" @@ -282,7 +284,7 @@ def __call__(self, *args) -> Any: # eagerly run the compiled graphs, which should be cudagraph capturable # as a whole. - concrete_size_entries = self.concrete_size_entries # default as general usage + concrete_size_entries = self.concrete_size_entries if self.separate_attention_routine and forward_context.is_pure_decoding: concrete_size_entries = self.concrete_size_entries_decode @@ -324,15 +326,16 @@ def __call__(self, *args) -> Any: entry.input_addresses = input_addresses cudagraph = torch.cuda.CUDAGraph() - with ExitStack() as stack: + with ExitStack(), \ + torch.cuda.graph(cudagraph, pool=self.graph_pool): # mind-exploding: carefully manage the reference and memory. - with torch.cuda.graph(cudagraph, pool=self.graph_pool): - # `output` is managed by pytorch's cudagraph pool - output = entry.runnable(*args) - # by converting it to weak ref, - # the original `output` will immediately be released - # to save memory. - output = weak_ref_tensors(output) + + # `output` is managed by pytorch's cudagraph pool + output = entry.runnable(*args) + # by converting it to weak ref, + # the original `output` will immediately be released + # to save memory. + output = weak_ref_tensors(output) # here we always use weak ref for the output # to save memory diff --git a/vllm/config.py b/vllm/config.py index bcf3052d9d8..c24cdb6b726 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -3976,11 +3976,11 @@ class CompilationConfig: performance benefits for smaller models.""" separate_attention_routine: bool = False """ - Enable a distinct attention calls routine under an attention backend for full - cuda graph capturing. This is because some attention backends like FlashMLA, - FlashInfer, FA2, etc. implement different branches for mix prefill-decode and - pure decode cases. 
This flag enables us to potentially capture the cudagraph - separately for each branch. + Enable a distinct attention calls routine under an attention backend for + full cuda graph capturing. This is because some attention backends like + FlashMLA, FlashInfer, FA2, etc. implement different branches for mix + prefill-decode and pure decode cases. This flag enables us to potentially + capture the cudagraph separately for each branch. """ pass_config: PassConfig = field(default_factory=PassConfig) @@ -4187,7 +4187,10 @@ def set_splitting_ops_for_v1(self): # the runtime batch_size is not cudagraph captured. This is only # supported for separate_attention_routine. if self.separate_attention_routine: - assert self.full_cuda_graph, "separate_attention_routine requires full_cuda_graph to be True" + assert self.full_cuda_graph, ( + "separate_attention_routine requires " + "full_cuda_graph to be True" + ) if not self.splitting_ops: self.splitting_ops = [ "vllm.unified_attention", diff --git a/vllm/forward_context.py b/vllm/forward_context.py index 8b7a187a877..94e8749f135 100644 --- a/vllm/forward_context.py +++ b/vllm/forward_context.py @@ -97,7 +97,7 @@ class ForwardContext: # determine whether to use a full cudagraph for attention or piecewise # cudagraphs that skip the attention part. By default true, we use piecewise # cudagraphs. - skip_attention_cuda_graphs: bool = True, + skip_attention_cuda_graphs: bool = True is_pure_decoding: bool = False diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index 01bec1fd18a..c5b2b0c6d76 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -227,14 +227,17 @@ def __init__(self, runner: GPUModelRunner, kv_cache_spec: AttentionSpec, self._prefill_wrapper = None # Wrapper for prefill/append self._decode_wrapper = None # Wrapper for decode self._decode_wrapper = None # Wrapper for decode (general shape) - self.enable_cuda_graph = self.vllm_config.compilation_config.full_cuda_graph + self.enable_cuda_graph = ( + self.vllm_config.compilation_config.full_cuda_graph + ) if self.enable_cuda_graph: # For full cudagraph capture, one `decode_wrapper` for each batch # size is needed for FlashInfer. 
- self._decode_wrappers_cudagraph: dict[int, BatchDecodeWithPagedKVCacheWrapper] = {} - self._decode_cudagraph_max_bs = min(runner.max_num_reqs, - runner.cudagraph_batch_sizes[-1]) - + self._decode_wrappers_cudagraph: dict[int, + BatchDecodeWithPagedKVCacheWrapper] = {} + self._decode_cudagraph_max_bs = min( + runner.max_num_reqs, runner.cudagraph_batch_sizes[-1]) + self._cascade_wrapper = None # Wrapper for cascade attention # Global hyperparameters shared by all attention layers @@ -446,8 +449,9 @@ def _plan(self, attn_metadata: FlashInferMetadata): use_cudagraph = (self.enable_cuda_graph and pure_decode and \ self._num_decodes <= self._decode_cudagraph_max_bs) if use_cudagraph: - num_input_tokens_decode = self.vllm_config.pad_for_cudagraph( - self._num_decodes) + num_input_tokens_decode = ( + self.vllm_config.pad_for_cudagraph(self._num_decodes) + ) else: num_input_tokens_decode = self._num_decodes diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index c701080cdfb..a368b8c288c 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1339,7 +1339,8 @@ def execute_model( if self.full_cuda_graph else True # Note: When skip_attention_cuda_graphs is always False and # compilition_config.separate_attention_routine is True, as in FA2, - # this flag helps to determine the correct routine to run for the full cudagraph. + # this flag helps to determine the correct routine for the full + # cudagraph. is_pure_decoding = num_scheduled_tokens == self.input_batch.num_reqs # Run the model. From 1606880c41b25db66999b3994895ce6c063ae0db Mon Sep 17 00:00:00 2001 From: fhl <2410591650@qq.com> Date: Wed, 25 Jun 2025 16:52:53 +0800 Subject: [PATCH 04/33] refactors Signed-off-by: fhl <2410591650@qq.com> --- vllm/compilation/backends.py | 3 ++- vllm/compilation/cuda_piecewise_backend.py | 11 +++++------ vllm/v1/attention/backends/flashinfer.py | 1 - vllm/v1/worker/gpu_model_runner.py | 17 +++++++++-------- 4 files changed, 16 insertions(+), 16 deletions(-) diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index fb9a6c94320..60eb4320bf1 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -587,7 +587,8 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable: fullgraph_wrapper = resolve_obj_by_qualname( current_platform.get_fullgraph_wrapper_cls()) self.split_gm = fullgraph_wrapper(self.split_gm, self.vllm_config, - self.graph_pool, self.sym_tensor_indices) + self.graph_pool, + self.sym_tensor_indices) if not self.compilation_config.use_cudagraph or \ not self.compilation_config.cudagraph_copy_inputs: diff --git a/vllm/compilation/cuda_piecewise_backend.py b/vllm/compilation/cuda_piecewise_backend.py index 8c2dff5d351..d098bec431c 100644 --- a/vllm/compilation/cuda_piecewise_backend.py +++ b/vllm/compilation/cuda_piecewise_backend.py @@ -143,10 +143,10 @@ def __call__(self, *args) -> Any: # Skip CUDA graphs if this entry doesn't use them OR # if we're supposed to treat the piecewise graphs as a whole, - # which implies forward_context.skip_attention_cuda_graphs is False. - # In the latter case, we rely on a wrapper class to capture - # the full cudagraph outside the fx graph. - skip_attention_cuda_graphs = get_forward_context().skip_attention_cuda_graphs + # In the latter case, forward_context.skip_attention_cuda_graphs + # is False, and we rely on a wrapper class to capture the full + # cudagraph outside the fx graph. 
+ skip_attention_cuda_graphs = get_forward_context().skip_attention_cuda_graphs #noqa if not entry.use_cudagraph or not skip_attention_cuda_graphs: return entry.runnable(*args) @@ -307,8 +307,7 @@ def __call__(self, *args) -> Any: "Warming up %s/%s of %s usage for shape %s", entry.num_finished_warmup, self.compilation_config.cudagraph_num_of_warmups, - entry.usage_type, - runtime_shape) + entry.usage_type, runtime_shape) return entry.runnable(*args) diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index c5b2b0c6d76..68c66621e06 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -225,7 +225,6 @@ def __init__(self, runner: GPUModelRunner, kv_cache_spec: AttentionSpec, self.vllm_config = runner.vllm_config self._workspace_buffer = None self._prefill_wrapper = None # Wrapper for prefill/append - self._decode_wrapper = None # Wrapper for decode self._decode_wrapper = None # Wrapper for decode (general shape) self.enable_cuda_graph = ( self.vllm_config.compilation_config.full_cuda_graph diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index a368b8c288c..eac4ba07467 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1894,7 +1894,7 @@ def rand_input_ids() -> torch.Tensor: def _dummy_run( self, num_tokens: int, - capture_attn_cudagraph: bool | Literal["auto"] = False, + capture_attn_cudagraph: Union[bool, Literal["auto"]] = False, is_pure_decoding: bool = False, ) -> tuple[torch.Tensor, torch.Tensor]: @@ -1923,9 +1923,10 @@ def _dummy_run( attn_metadata: Optional[dict[str, Any]] = None skip_attention_cuda_graphs = True if capture_attn_cudagraph: - # Note: At this step, `capture_attn_cudagraph` should be True or "auto", - # but we always treat it as "auto". i.e., always let the attention backends - # to determine whether to capture the attention or not. + # Note: At this step, `capture_attn_cudagraph` should be True or + # "auto", but we always treat it as "auto". i.e., always let the + # attention backends to determine whether to capture the attention + # or not. attn_metadata = {} query_start_loc = self.query_start_loc[:num_reqs + 1] @@ -1944,7 +1945,7 @@ def _dummy_run( max_query_len=max_query_len, ) # If all attention backends can run in a cudagraph, we use a full - # cudagraph for attention. Otherwise, turn back to piecewise cudagraphs. + # cudagraph for attention. Otherwise, back to piecewise cudagraphs. attention_cuda_graphs = all( b.can_run_in_cudagraph(common_attn_metadata) for b in self.attn_metadata_builders) @@ -2237,9 +2238,9 @@ def capture_model(self) -> None: with graph_capture(device=self.device): full_cg = self.full_cuda_graph - # If full_cuda_graph is true, automatically determine whether or not - # to capture the attention for the mix prefill-decode (general) phase, - # based on the attention backends. + # If full_cuda_graph is true, automatically determine whether or + # not to capture the attention for the mix prefill-decode (general) + # phase, based on the attention backends. 
capture_attn_cudagraph_general = "auto" if full_cg else False # Skip capturing batch sizes of 1 in mix prefill-decode if From 7c5df45f3ce06cdc3e3f8c6ed43c19f930eba95f Mon Sep 17 00:00:00 2001 From: fhl2000 <63384265+fhl2000@users.noreply.github.com> Date: Wed, 25 Jun 2025 10:03:40 +0000 Subject: [PATCH 05/33] refactor Signed-off-by: fhl2000 <63384265+fhl2000@users.noreply.github.com> --- vllm/compilation/backends.py | 4 +- vllm/compilation/base_piecewise_backend.py | 2 +- vllm/compilation/cuda_piecewise_backend.py | 63 +++++++++--------- vllm/config.py | 5 +- vllm/forward_context.py | 4 +- vllm/platforms/cuda.py | 2 +- vllm/platforms/interface.py | 3 +- vllm/v1/attention/backends/flash_attn.py | 1 - vllm/v1/attention/backends/flashinfer.py | 77 +++++++++++----------- vllm/v1/worker/gpu_model_runner.py | 68 ++++++++++--------- 10 files changed, 117 insertions(+), 112 deletions(-) diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 60eb4320bf1..542869687ab 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -587,13 +587,13 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable: fullgraph_wrapper = resolve_obj_by_qualname( current_platform.get_fullgraph_wrapper_cls()) self.split_gm = fullgraph_wrapper(self.split_gm, self.vllm_config, - self.graph_pool, + self.graph_pool, self.sym_tensor_indices) if not self.compilation_config.use_cudagraph or \ not self.compilation_config.cudagraph_copy_inputs: return self.split_gm - + # compiler managed cudagraph input buffers # we assume the first run with symbolic shapes # has the maximum size among all the tensors diff --git a/vllm/compilation/base_piecewise_backend.py b/vllm/compilation/base_piecewise_backend.py index a728f6f3724..854c9146543 100644 --- a/vllm/compilation/base_piecewise_backend.py +++ b/vllm/compilation/base_piecewise_backend.py @@ -99,7 +99,7 @@ def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig, """ raise NotImplementedError - + def __call__(self, *args) -> Any: """ Executes the wrapped graph for given input args. diff --git a/vllm/compilation/cuda_piecewise_backend.py b/vllm/compilation/cuda_piecewise_backend.py index d098bec431c..5e59f6ab4cc 100644 --- a/vllm/compilation/cuda_piecewise_backend.py +++ b/vllm/compilation/cuda_piecewise_backend.py @@ -144,9 +144,10 @@ def __call__(self, *args) -> Any: # Skip CUDA graphs if this entry doesn't use them OR # if we're supposed to treat the piecewise graphs as a whole, # In the latter case, forward_context.skip_attention_cuda_graphs - # is False, and we rely on a wrapper class to capture the full + # is False, and we rely on a wrapper class to capture the full # cudagraph outside the fx graph. - skip_attention_cuda_graphs = get_forward_context().skip_attention_cuda_graphs #noqa + skip_attention_cuda_graphs = get_forward_context( + ).skip_attention_cuda_graphs if not entry.use_cudagraph or not skip_attention_cuda_graphs: return entry.runnable(*args) @@ -158,8 +159,7 @@ def __call__(self, *args) -> Any: "Warming up %s/%s of %s usage for shape %s", entry.num_finished_warmup, self.compilation_config.cudagraph_num_of_warmups, - entry.usage_type, - runtime_shape) + entry.usage_type, runtime_shape) return entry.runnable(*args) if self.is_first_graph: @@ -167,8 +167,7 @@ def __call__(self, *args) -> Any: # capturing is fast, we don't need to log it for every shape. # We only log it in the debug mode. 
logger.debug("Capturing a cudagraph of %s usage for shape %s", - entry.usage_type, - runtime_shape) + entry.usage_type, runtime_shape) input_addresses = [ x.data_ptr() for x in args if isinstance(x, torch.Tensor) @@ -227,9 +226,14 @@ def __call__(self, *args) -> Any: class FullCudagraphWrapper: - def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig, - graph_pool: Any, sym_shape_indices: list[int], - ): + + def __init__( + self, + graph: fx.GraphModule, + vllm_config: VllmConfig, + graph_pool: Any, + sym_shape_indices: list[int], + ): self.graph = graph self.vllm_config = vllm_config self.compilation_config = vllm_config.compilation_config @@ -237,21 +241,19 @@ def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig, self.sym_shape_indices = sym_shape_indices self.separate_attention_routine = ( - vllm_config.compilation_config.separate_attention_routine - ) + vllm_config.compilation_config.separate_attention_routine) self.is_debugging_mode = envs.VLLM_LOGGING_LEVEL == "DEBUG" self.first_run_finished = False self.cudagraph_capture_sizes: set[int] = set( - self.compilation_config.cudagraph_capture_sizes - ) if self.compilation_config.use_cudagraph else set() + self.compilation_config.cudagraph_capture_sizes + ) if self.compilation_config.use_cudagraph else set() self.concrete_size_entries: dict[int, ConcreteSizeEntry] = {} self.concrete_size_entries_decode: dict[int, ConcreteSizeEntry] = {} - for shape in self.cudagraph_capture_sizes: self.concrete_size_entries[shape] = ConcreteSizeEntry( runtime_shape=shape, @@ -278,17 +280,17 @@ def __call__(self, *args) -> Any: if forward_context.skip_attention_cuda_graphs: # turn back to piecewise cudagraphs backend, which is responsible # for capturing and running the piecewise cudagraphs. - return self.graph(*args) - - # if not skip, the fx graph and its sub-graphs will only be supposed to + return self.graph(*args) + + # if not skip, the fx graph and its sub-graphs will only be supposed to # eagerly run the compiled graphs, which should be cudagraph capturable # as a whole. - - concrete_size_entries = self.concrete_size_entries + + concrete_size_entries = self.concrete_size_entries if self.separate_attention_routine and forward_context.is_pure_decoding: concrete_size_entries = self.concrete_size_entries_decode - if not runtime_shape in concrete_size_entries: + if runtime_shape not in concrete_size_entries: # we don't need to do anything for this shape. return self.graph(*args) @@ -303,21 +305,18 @@ def __call__(self, *args) -> Any: if entry.cudagraph is None: if entry.num_finished_warmup < self.compilation_config.cudagraph_num_of_warmups: # noqa entry.num_finished_warmup += 1 - logger.debug( - "Warming up %s/%s of %s usage for shape %s", - entry.num_finished_warmup, - self.compilation_config.cudagraph_num_of_warmups, - entry.usage_type, runtime_shape) + logger.debug("Warming up %s/%s of %s usage for shape %s", + entry.num_finished_warmup, + self.compilation_config.cudagraph_num_of_warmups, + entry.usage_type, runtime_shape) return entry.runnable(*args) - # Since we capture cudagraph for many different shapes and # capturing is fast, we don't need to log it for every shape. # We only log it in the debug mode. 
- + logger.debug("Capturing a cudagraph of %s usage for shape %s", - entry.usage_type, - runtime_shape) + entry.usage_type, runtime_shape) input_addresses = [ x.data_ptr() for x in args if isinstance(x, torch.Tensor) @@ -328,12 +327,12 @@ def __call__(self, *args) -> Any: with ExitStack(), \ torch.cuda.graph(cudagraph, pool=self.graph_pool): # mind-exploding: carefully manage the reference and memory. - + # `output` is managed by pytorch's cudagraph pool output = entry.runnable(*args) # by converting it to weak ref, # the original `output` will immediately be released - # to save memory. + # to save memory. output = weak_ref_tensors(output) # here we always use weak ref for the output @@ -359,4 +358,4 @@ def __call__(self, *args) -> Any: ) entry.cudagraph.replay() - return entry.output \ No newline at end of file + return entry.output diff --git a/vllm/config.py b/vllm/config.py index 8dd57da8996..8e2c2eb7140 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -4192,15 +4192,14 @@ def set_splitting_ops_for_v1(self): # NOTE: this function needs to be called # NOTE: When full_cuda_graph is True, instead of setting an empty # list and capture the full cudagraph inside the flattened fx graph, - # we keep the piecewise fx graph structure but capture the full + # we keep the piecewise fx graph structure but capture the full # cudagraph outside the fx graph. This reduces some cpu overhead when # the runtime batch_size is not cudagraph captured. This is only # supported for separate_attention_routine. if self.separate_attention_routine: assert self.full_cuda_graph, ( "separate_attention_routine requires " - "full_cuda_graph to be True" - ) + "full_cuda_graph to be True") if not self.splitting_ops: self.splitting_ops = [ "vllm.unified_attention", diff --git a/vllm/forward_context.py b/vllm/forward_context.py index 94e8749f135..6440af712a8 100644 --- a/vllm/forward_context.py +++ b/vllm/forward_context.py @@ -94,8 +94,8 @@ class ForwardContext: virtual_engine: int # set dynamically for each forward pass # set dynamically for each forward pass dp_metadata: Optional[DPMetadata] = None - # determine whether to use a full cudagraph for attention or piecewise - # cudagraphs that skip the attention part. By default true, we use piecewise + # determine whether to use a full cudagraph for attention or piecewise + # cudagraphs that skip the attention part. By default true, we use piecewise # cudagraphs. skip_attention_cuda_graphs: bool = True is_pure_decoding: bool = False diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index de99aee9fea..eb329c342a2 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -370,7 +370,7 @@ def use_custom_allreduce(cls) -> bool: @classmethod def get_piecewise_backend_cls(cls) -> str: return "vllm.compilation.cuda_piecewise_backend.CUDAPiecewiseBackend" # noqa - + @classmethod def get_fullgraph_wrapper_cls(cls) -> str: return "vllm.compilation.cuda_piecewise_backend.FullCudagraphWrapper" # noqa diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index 2c18616c8bc..3e88aafab53 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -531,7 +531,7 @@ def get_piecewise_backend_cls(cls) -> str: Get piecewise backend class for piecewise graph. 
""" return "vllm.compilation.base_piecewise_backend.AbstractPiecewiseBackend" # noqa - + @classmethod def get_fullgraph_wrapper_cls(cls) -> str: """ @@ -539,7 +539,6 @@ def get_fullgraph_wrapper_cls(cls) -> str: """ return "vllm.compilation.base_piecewise_backend.AbstractFullgraphWrapper" # noqa - @classmethod def stateless_init_device_torch_dist_pg( cls, diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index 28022870e4c..27c96f40db6 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -167,7 +167,6 @@ def __init__(self, runner: "GPUModelRunner", kv_cache_spec: AttentionSpec, # by the kernel launching code. self.aot_schedule = False - # Sliding window size to be used with the AOT scheduler will be # populated on first build() call. self.aot_sliding_window: Optional[tuple[int, int]] = None diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index 68c66621e06..0b689c24e28 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -4,7 +4,7 @@ from __future__ import annotations from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Optional, ClassVar +from typing import TYPE_CHECKING, Any, ClassVar, Optional import torch from flashinfer import (BatchDecodeWithPagedKVCacheWrapper, @@ -219,6 +219,7 @@ def __post_init__(self): class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): full_cudagraph_supported: ClassVar[bool] = True + def __init__(self, runner: GPUModelRunner, kv_cache_spec: AttentionSpec, block_table: BlockTable): self.runner = runner @@ -227,13 +228,12 @@ def __init__(self, runner: GPUModelRunner, kv_cache_spec: AttentionSpec, self._prefill_wrapper = None # Wrapper for prefill/append self._decode_wrapper = None # Wrapper for decode (general shape) self.enable_cuda_graph = ( - self.vllm_config.compilation_config.full_cuda_graph - ) + self.vllm_config.compilation_config.full_cuda_graph) if self.enable_cuda_graph: # For full cudagraph capture, one `decode_wrapper` for each batch # size is needed for FlashInfer. 
- self._decode_wrappers_cudagraph: dict[int, - BatchDecodeWithPagedKVCacheWrapper] = {} + self._decode_wrappers_cudagraph: dict[ + int, BatchDecodeWithPagedKVCacheWrapper] = {} self._decode_cudagraph_max_bs = min( runner.max_num_reqs, runner.cudagraph_batch_sizes[-1]) @@ -246,16 +246,16 @@ def __init__(self, runner: GPUModelRunner, kv_cache_spec: AttentionSpec, self.block_table = block_table # Preparing persistent buffers - self.paged_kv_indptr = torch.zeros( - self.runner.max_num_reqs + 1, - dtype=torch.int32, - device=self.runner.device) + self.paged_kv_indptr = torch.zeros(self.runner.max_num_reqs + 1, + dtype=torch.int32, + device=self.runner.device) self.paged_kv_indices = torch.zeros( - block_table.get_device_tensor().numel(), # max num pages possible - dtype=torch.int32, device=self.runner.device) - self.paged_kv_last_page_len = torch.zeros( - self.runner.max_num_reqs, - dtype=torch.int32, device=self.runner.device) + block_table.get_device_tensor().numel(), # max num pages possible + dtype=torch.int32, + device=self.runner.device) + self.paged_kv_last_page_len = torch.zeros(self.runner.max_num_reqs, + dtype=torch.int32, + device=self.runner.device) def reorder_batch(self, input_batch: InputBatch, scheduler_output: SchedulerOutput) -> bool: @@ -330,9 +330,12 @@ def _get_prefill_wrapper(self): self._get_workspace_buffer(), get_kv_cache_layout()) return self._prefill_wrapper - def _get_decode_wrapper(self, batch_size: int, use_cudagraph: bool = False): + def _get_decode_wrapper(self, + batch_size: int, + use_cudagraph: bool = False): if use_cudagraph: - decode_wrapper = self._decode_wrappers_cudagraph.get(batch_size, None) + decode_wrapper = self._decode_wrappers_cudagraph.get( + batch_size, None) else: decode_wrapper = self._decode_wrapper @@ -343,11 +346,12 @@ def _get_decode_wrapper(self, batch_size: int, use_cudagraph: bool = False): self.runner.parallel_config) use_tensor_cores = envs.VLLM_FLASHINFER_FORCE_TENSOR_CORES or ( num_qo_heads // num_kv_heads > 4) - + if use_cudagraph: paged_kv_indptr = self.paged_kv_indptr[:batch_size + 1] paged_kv_indices = self.paged_kv_indices - paged_kv_last_page_len = self.paged_kv_last_page_len[:batch_size] + paged_kv_last_page_len = self.paged_kv_last_page_len[: + batch_size] else: paged_kv_indptr = None paged_kv_indices = None @@ -449,15 +453,14 @@ def _plan(self, attn_metadata: FlashInferMetadata): self._num_decodes <= self._decode_cudagraph_max_bs) if use_cudagraph: num_input_tokens_decode = ( - self.vllm_config.pad_for_cudagraph(self._num_decodes) - ) + self.vllm_config.pad_for_cudagraph(self._num_decodes)) else: num_input_tokens_decode = self._num_decodes attn_metadata.decode_wrapper = self._get_decode_wrapper( - num_input_tokens_decode, use_cudagraph) + num_input_tokens_decode, use_cudagraph) # TODO: Override flashinfer's plan function to avoid some - # host-to-device copy overhead. + # host-to-device copy overhead. attn_metadata.decode_wrapper.plan( # NOTE: Use the persistent buffer with padding length, # instead of the same address but chunked length buffers in @@ -538,8 +541,8 @@ def build(self, common_prefix_len: int, < block_table_bounds.unsqueeze(1)) paged_kv_indices = block_table_tensor[mask] num_actual_pages = paged_kv_indices.size(0) - self.paged_kv_indices[:num_actual_pages].copy_( - paged_kv_indices, non_blocking=True) + self.paged_kv_indices[:num_actual_pages].copy_(paged_kv_indices, + non_blocking=True) # Fill the remaining paged_kv_last_page_len with 1. This is because # flashinfer treats 0 as a full page instead of empty. 
self.paged_kv_indices[num_actual_pages:].fill_(-1) @@ -550,24 +553,22 @@ def build(self, common_prefix_len: int, device=block_table_bounds.device), block_table_bounds.cumsum(dim=0, dtype=torch.int32) ]) - self.paged_kv_indptr[:1+num_reqs].copy_( - paged_kv_indptr, non_blocking=True) + self.paged_kv_indptr[:1 + num_reqs].copy_(paged_kv_indptr, + non_blocking=True) # make sure self.paged_kv_indptr is not decreasing - self.paged_kv_indptr[1+num_reqs:].fill_( - paged_kv_indptr[-1]) + self.paged_kv_indptr[1 + num_reqs:].fill_(paged_kv_indptr[-1]) paged_kv_last_page_len = seq_lens % page_size paged_kv_last_page_len = torch.where(paged_kv_last_page_len == 0, page_size, paged_kv_last_page_len) - self.paged_kv_last_page_len[:num_reqs].copy_( - paged_kv_last_page_len, non_blocking=True) - self.paged_kv_last_page_len[num_reqs:].fill_( - 1) + self.paged_kv_last_page_len[:num_reqs].copy_(paged_kv_last_page_len, + non_blocking=True) + self.paged_kv_last_page_len[num_reqs:].fill_(1) attn_metadata = FlashInferMetadata( num_actual_tokens=num_actual_tokens, qo_indptr=qo_indptr, - paged_kv_indptr=self.paged_kv_indptr[:1+num_reqs], + paged_kv_indptr=self.paged_kv_indptr[:1 + num_reqs], paged_kv_indices=self.paged_kv_indices[:num_actual_pages], paged_kv_last_page_len=self.paged_kv_last_page_len[:num_reqs], num_qo_heads=self.runner.num_query_heads, @@ -591,7 +592,7 @@ def build(self, common_prefix_len: int, self._plan(attn_metadata) return attn_metadata - + def build_for_cudagraph_capture( self, common_attn_metadata: CommonAttentionMetadata): """ @@ -599,9 +600,9 @@ def build_for_cudagraph_capture( Currently, only decode is supported for full cudagraphs with FlashInfer. """ m = common_attn_metadata - m.query_start_loc.copy_(torch.arange(m.num_actual_tokens+1, - dtype=torch.int32, - device=self.runner.device), + m.query_start_loc.copy_(torch.arange(m.num_actual_tokens + 1, + dtype=torch.int32, + device=self.runner.device), non_blocking=True) assert m.num_reqs == m.num_actual_tokens, \ "FlashInfer only supports decode-only full CUDAGraph capture. " \ @@ -615,7 +616,7 @@ def build_for_cudagraph_capture( self._num_prefills = 0 self._num_prefill_tokens = 0 return self.build(0, m) - + def can_run_in_cudagraph( self, common_attn_metadata: CommonAttentionMetadata) -> bool: return common_attn_metadata.max_query_len == 1 diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index eac4ba07467..c613402fa2f 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -6,7 +6,7 @@ import time import weakref from contextlib import contextmanager -from typing import TYPE_CHECKING, Any, Optional, Union, Literal +from typing import TYPE_CHECKING, Any, Literal, Optional, Union import numpy as np import torch @@ -1333,7 +1333,7 @@ def execute_model( # Some attention backends only support CUDA Graphs in pure decode. # If attention doesn't support CUDA Graphs for this batch, we skip them, - # and turn back to the piecewise CUDA graphs. Or if full_cuda_graph is + # and turn back to the piecewise CUDA graphs. Or if full_cuda_graph is # False, we always turn to the piecewise CUDA graphs. 
skip_attention_cuda_graphs = not attention_cuda_graphs \ if self.full_cuda_graph else True @@ -1915,14 +1915,14 @@ def _dummy_run( assert len(num_scheduled_tokens_list) == num_reqs num_scheduled_tokens = np.array(num_scheduled_tokens_list, dtype=np.int32) - + # [Bugfix] This lets FA2 to correctly activate the optimized routine # for pure decoding, i.e., Flashdecoding + an optimization for GQA/MQA. - max_query_len = 1 if is_pure_decoding else num_tokens + max_query_len = 1 if is_pure_decoding else num_tokens attn_metadata: Optional[dict[str, Any]] = None skip_attention_cuda_graphs = True - if capture_attn_cudagraph: + if capture_attn_cudagraph: # Note: At this step, `capture_attn_cudagraph` should be True or # "auto", but we always treat it as "auto". i.e., always let the # attention backends to determine whether to capture the attention @@ -1962,7 +1962,7 @@ def _dummy_run( for layer_name in kv_cache_group_spec.layer_names: attn_metadata[layer_name] = attn_metadata_i else: - attn_metadata = None # reset to None other than empty dict + attn_metadata = None # reset to None other than empty dict with self.maybe_dummy_run_with_lora(self.lora_config, num_scheduled_tokens): @@ -1996,7 +1996,7 @@ def _dummy_run( self.vllm_config, num_tokens=num_tokens, num_tokens_across_dp=num_tokens_across_dp, - skip_attention_cuda_graphs=skip_attention_cuda_graphs, + skip_attention_cuda_graphs=skip_attention_cuda_graphs, is_pure_decoding=is_pure_decoding): outputs = model( input_ids=input_ids, @@ -2243,40 +2243,48 @@ def capture_model(self) -> None: # phase, based on the attention backends. capture_attn_cudagraph_general = "auto" if full_cg else False - # Skip capturing batch sizes of 1 in mix prefill-decode if - # separate_attention_routine is on. As bs=1 can treat as a - # pure decode. - start_idx = 0 + # Skip capturing batch sizes of 1 in mix prefill-decode if + # separate_attention_routine is on. As bs=1 can treat as a + # pure decode. + start_idx = 0 if self.vllm_config.compilation_config.separate_attention_routine \ and len(self.cudagraph_batch_sizes) > 0 \ and self.cudagraph_batch_sizes[0] == 1: start_idx = 1 # Capture the mix prefill-decode (general usage) cudagraphs - for num_tokens in tqdm(reversed(self.cudagraph_batch_sizes[start_idx:]), - desc="Capturing CUDA graphs (mix prefill-decode)", - total=len(self.cudagraph_batch_sizes)): + for num_tokens in tqdm( + reversed(self.cudagraph_batch_sizes[start_idx:]), + desc="Capturing CUDA graphs (mix prefill-decode)", + total=len(self.cudagraph_batch_sizes)): for _ in range( self.compilation_config.cudagraph_num_of_warmups): - self._dummy_run(num_tokens, - capture_attn_cudagraph=capture_attn_cudagraph_general, - is_pure_decoding=False) - self._dummy_run(num_tokens, - capture_attn_cudagraph=capture_attn_cudagraph_general, - is_pure_decoding=False) + self._dummy_run( + num_tokens, + capture_attn_cudagraph=capture_attn_cudagraph_general, + is_pure_decoding=False) + self._dummy_run( + num_tokens, + capture_attn_cudagraph=capture_attn_cudagraph_general, + is_pure_decoding=False) if self.vllm_config.compilation_config.separate_attention_routine: # Capture the pure decode cudagraphs. 
Typically a full cudagraph - + max_num_reqs = self.scheduler_config.max_num_seqs - decode_cudagraph_batch_sizes = [x for x in self.cudagraph_batch_sizes - if x <= max_num_reqs] - for num_tokens in tqdm(reversed(decode_cudagraph_batch_sizes), - desc="Capturing CUDA graphs (pure decode)", - total=len(decode_cudagraph_batch_sizes)): - for _ in range(self.compilation_config.cudagraph_num_of_warmups): - self._dummy_run(num_tokens, capture_attn_cudagraph=full_cg, - is_pure_decoding=True) - self._dummy_run(num_tokens, capture_attn_cudagraph=full_cg, + decode_cudagraph_batch_sizes = [ + x for x in self.cudagraph_batch_sizes if x <= max_num_reqs + ] + for num_tokens in tqdm( + reversed(decode_cudagraph_batch_sizes), + desc="Capturing CUDA graphs (pure decode)", + total=len(decode_cudagraph_batch_sizes)): + for _ in range( + self.compilation_config.cudagraph_num_of_warmups): + self._dummy_run(num_tokens, + capture_attn_cudagraph=full_cg, + is_pure_decoding=True) + self._dummy_run(num_tokens, + capture_attn_cudagraph=full_cg, is_pure_decoding=True) end_time = time.perf_counter() From c7a9424be261f565ab5941db0da8ec3894bc85ec Mon Sep 17 00:00:00 2001 From: fhl2000 <63384265+fhl2000@users.noreply.github.com> Date: Wed, 25 Jun 2025 14:33:55 +0000 Subject: [PATCH 06/33] Add check for separate_attention_routine flag Signed-off-by: fhl2000 <63384265+fhl2000@users.noreply.github.com> --- vllm/v1/attention/backends/flash_attn.py | 3 +++ vllm/v1/attention/backends/flashinfer.py | 1 + vllm/v1/attention/backends/mla/flashmla.py | 1 + vllm/v1/attention/backends/triton_attn.py | 1 + vllm/v1/attention/backends/utils.py | 5 +++- vllm/v1/worker/gpu_model_runner.py | 29 +++++++++++++++++----- 6 files changed, 33 insertions(+), 7 deletions(-) diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index 27c96f40db6..fdd5bd50441 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -140,6 +140,9 @@ def _get_sliding_window_configs( class FlashAttentionMetadataBuilder( AttentionMetadataBuilder[FlashAttentionMetadata]): full_cudagraph_supported: ClassVar[bool] = get_flash_attn_version() >= 2 + # FlashAttn support a unified varlen fwd kernel for prefill-decode phase, so + # it's ok to either separate attention routine or not for both FA2 or 3. 
+    force_separate_routine: ClassVar[Optional[bool]] = None
 
     def __init__(self, runner: "GPUModelRunner", kv_cache_spec: AttentionSpec,
                  block_table: BlockTable):
diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py
index 0b689c24e28..3a21dc9e7d4 100755
--- a/vllm/v1/attention/backends/flashinfer.py
+++ b/vllm/v1/attention/backends/flashinfer.py
@@ -219,6 +219,7 @@ def __post_init__(self):
 
 class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
     full_cudagraph_supported: ClassVar[bool] = True
+    force_separate_routine: ClassVar[Optional[bool]] = True
 
     def __init__(self, runner: GPUModelRunner, kv_cache_spec: AttentionSpec,
                  block_table: BlockTable):
diff --git a/vllm/v1/attention/backends/mla/flashmla.py b/vllm/v1/attention/backends/mla/flashmla.py
index be26e0060db..d49a87c8a57 100644
--- a/vllm/v1/attention/backends/mla/flashmla.py
+++ b/vllm/v1/attention/backends/mla/flashmla.py
@@ -55,6 +55,7 @@ class FlashMLAMetadata(MLACommonMetadata[FlashMLADecodeMetadata]):
 
 class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
     full_cudagraph_supported: ClassVar[bool] = True  # Decode-only
+    force_separate_routine: ClassVar[Optional[bool]] = True
 
     def __init__(self, runner, kv_cache_spec: AttentionSpec,
                  block_table: BlockTable):
diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py
index 4c5a1a755c1..eb0e40a699d 100644
--- a/vllm/v1/attention/backends/triton_attn.py
+++ b/vllm/v1/attention/backends/triton_attn.py
@@ -74,6 +74,7 @@ class LocalAttentionMetadata:
 class TritonAttentionMetadataBuilder(
         AttentionMetadataBuilder[TritonAttentionMetadata]):
     full_cudagraph_supported: ClassVar[bool] = True
+    force_separate_routine: ClassVar[Optional[bool]] = None
 
     def __init__(self, runner: "GPUModelRunner", kv_cache_spec: AttentionSpec,
                  block_table: BlockTable):
diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py
index 8083f200260..9836a04fb58 100644
--- a/vllm/v1/attention/backends/utils.py
+++ b/vllm/v1/attention/backends/utils.py
@@ -4,7 +4,7 @@
 import functools
 from abc import abstractmethod
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, ClassVar, Generic, TypeVar
+from typing import TYPE_CHECKING, ClassVar, Generic, Optional, TypeVar
 
 import numpy as np
 import torch
@@ -50,6 +50,9 @@ class CommonAttentionMetadata:
 class AttentionMetadataBuilder(abc.ABC, Generic[M]):
     # Does this backend/builder support CUDA Graphs for attention.
     full_cudagraph_supported: ClassVar[bool] = False
+    # If full cudagraph is supported, select whether this attention backend
+    # enforces the separate routine to be True, False or None (free).
+    force_separate_routine: ClassVar[Optional[bool]] = None
 
     @abstractmethod
     def build(self, common_prefix_len: int,
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index c613402fa2f..e861d2113dc 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -2338,12 +2338,29 @@ def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None:
                 block_table_i,
             )
 
-            if (self.full_cuda_graph
-                    and not attn_metadata_builder_i.full_cudagraph_supported):
-                raise ValueError(
-                    f"Full CUDAGraph not supported for "
-                    f"{attn_backend_i.__name__}. Turn off CompilationConfig."
- f"full_cuda_graph or use a different attention backend.") + if self.full_cuda_graph: + if not attn_metadata_builder_i.full_cudagraph_supported: + raise ValueError( + f"Full CUDAGraph not supported for " + f"{attn_backend_i.__name__}. Turn off " + f"CompilationConfig.full_cuda_graph or use a different" + f" attention backend.") + + # check if the attention backends enforce to have separate + # routines for mix prefill-decode and pure decode phase + if attn_metadata_builder_i.force_separate_routine is not None \ + and self.compilation_config.separate_attention_rountine\ + != attn_metadata_builder_i.force_separate_routine: + + expected = attn_metadata_builder_i.force_separate_routine + logger.warning_once( + f"Full CUDAGraph for {attn_backend_i.__name__}" + f"enforce CompilationConfig.separate_attention" + f"_rountine as: {expected}. Now set it to: " + f"{expected}.") + + self.compilation_config.separate_attention_rountine = \ + expected self.attn_backends.append(attn_backend_i) self.attn_metadata_builders.append(attn_metadata_builder_i) From e8b929624bddb6b5868a3001069f2c526bb8a5c0 Mon Sep 17 00:00:00 2001 From: fhl <2410591650@qq.com> Date: Thu, 26 Jun 2025 17:14:44 +0800 Subject: [PATCH 07/33] fix typo error Signed-off-by: fhl <2410591650@qq.com> --- vllm/v1/worker/gpu_model_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index e861d2113dc..2e6bdc02333 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2349,7 +2349,7 @@ def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None: # check if the attention backends enforce to have separate # routines for mix prefill-decode and pure decode phase if attn_metadata_builder_i.force_separate_routine is not None \ - and self.compilation_config.separate_attention_rountine\ + and self.compilation_config.separate_attention_routine\ != attn_metadata_builder_i.force_separate_routine: expected = attn_metadata_builder_i.force_separate_routine From a67c698c43e34b351df9534eda5da534fecbd40b Mon Sep 17 00:00:00 2001 From: fhl2000 <63384265+fhl2000@users.noreply.github.com> Date: Sat, 28 Jun 2025 06:06:12 +0000 Subject: [PATCH 08/33] refactors and rearchitect cuda graph logic Signed-off-by: fhl2000 <63384265+fhl2000@users.noreply.github.com> --- vllm/compilation/cuda_piecewise_backend.py | 261 +++++++++++---------- vllm/forward_context.py | 6 +- vllm/v1/attention/backends/flash_attn.py | 2 +- vllm/v1/attention/backends/triton_attn.py | 1 - vllm/v1/worker/gpu_model_runner.py | 31 ++- 5 files changed, 149 insertions(+), 152 deletions(-) diff --git a/vllm/compilation/cuda_piecewise_backend.py b/vllm/compilation/cuda_piecewise_backend.py index 5e59f6ab4cc..4feeeea9aa9 100644 --- a/vllm/compilation/cuda_piecewise_backend.py +++ b/vllm/compilation/cuda_piecewise_backend.py @@ -33,6 +33,8 @@ class ConcreteSizeEntry: cudagraph: Optional[torch.cuda.CUDAGraph] = None output: Optional[Any] = None + cudagraph_runnable: Optional[Callable] = None + # for cudagraph debugging, track the input addresses # during capture, and check if they are the same during replay input_addresses: Optional[list[int]] = None @@ -40,6 +42,108 @@ class ConcreteSizeEntry: usage_type: Optional[str] = None +class CUDAGraphWrapper: + """ + A wrapper class for cudagraphs functionality. + + This class creates a cudagraph runnable for a given `ConcreteSizeEntry`, + taking responsibility of capturing cudagraph and running the replay. 
+ """ + + def __init__(self, vllm_config: VllmConfig): + self.vllm_config = vllm_config + self.compilation_config = vllm_config.compilation_config + self.is_debugging_mode = envs.VLLM_LOGGING_LEVEL == "DEBUG" + + def create_cudagraph_runnable( + self, entry: ConcreteSizeEntry, + cudagraph_runnable_config: dict[str, Any]) -> Any: + graph_pool = cudagraph_runnable_config["graph_pool"] + assert graph_pool is not None + debug_capturing = cudagraph_runnable_config.get( + "debug_capturing", True) + gc_disable = cudagraph_runnable_config.get("gc_disable", False) + weak_ref_output = cudagraph_runnable_config.get( + "weak_ref_output", True) + + def cudagraph_runnable(*args): + if entry.cudagraph is None: + if entry.num_finished_warmup < self.compilation_config.cudagraph_num_of_warmups: # noqa + entry.num_finished_warmup += 1 + if debug_capturing: + logger.debug( + "Warming up %s/%s of %s usage for shape %s", + entry.num_finished_warmup, + self.compilation_config.cudagraph_num_of_warmups, + entry.usage_type, entry.runtime_shape) + return entry.runnable(*args) + + if debug_capturing: + # Since we capture cudagraph for many different shapes and + # capturing is fast, we don't need to log it for every + # shape. We only log it in the debug mode. + logger.debug( + "Capturing a cudagraph of %s usage for shape %s", + entry.usage_type, entry.runtime_shape) + + input_addresses = [ + x.data_ptr() for x in args if isinstance(x, torch.Tensor) + ] + entry.input_addresses = input_addresses + cudagraph = torch.cuda.CUDAGraph() + + with ExitStack() as stack: + if gc_disable: + # during every model forward, we will capture + # many pieces of cudagraphs (roughly one per layer). + # running gc again and again across layers will + # make the cudagraph capture very slow. + # therefore, we only run gc for the first graph, + # and disable gc for the rest of the graphs. + stack.enter_context(patch("gc.collect", lambda: None)) + stack.enter_context( + patch("torch.cuda.empty_cache", lambda: None)) + + # mind-exploding: carefully manage the reference and memory. + with torch.cuda.graph(cudagraph, pool=graph_pool): + # `output` is managed by pytorch's cudagraph pool + output = entry.runnable(*args) + if weak_ref_output: + # by converting it to weak ref, + # the original `output` will immediately be released + # to save memory. It is only safe to do this for + # the last graph, because the output of the last + # graph will not be used by any other cuda graph. + output = weak_ref_tensors(output) + + # here we always use weak ref for the output + # to save memory + entry.output = weak_ref_tensors(output) + entry.cudagraph = cudagraph + + compilation_counter.num_cudagraph_captured += 1 + + # important: we need to return the output, rather than + # the weak ref of the output, so that pytorch can correctly + # manage the memory during cuda graph capture + return output + + if self.is_debugging_mode: + # check if the input addresses are the same + new_input_addresses = [ + x.data_ptr() for x in args if isinstance(x, torch.Tensor) + ] + assert new_input_addresses == entry.input_addresses, ( + "Input addresses for cudagraphs are different during " + f"replay. 
Expected {entry.input_addresses}, got " + f"{new_input_addresses}") + + entry.cudagraph.replay() + return entry.output + + return cudagraph_runnable + + class CUDAPiecewiseBackend: def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig, @@ -72,6 +176,8 @@ def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig, self.is_last_graph = ( piecewise_compile_index == total_piecewise_compiles - 1) + self.is_full_graph = total_piecewise_compiles == 1 + self.compile_sizes: set[int] = set( self.compilation_config.compile_sizes) self.cudagraph_capture_sizes: set[int] = set( @@ -94,13 +200,17 @@ def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig, # and updates during the compilation process, so we need to copy it self.to_be_compiled_sizes: set[int] = self.compile_sizes.copy() for shape in self.compile_sizes.union(self.cudagraph_capture_sizes): + usage_type = "full/general" if self.is_full_graph else \ + "piecewise/general" self.concrete_size_entries[shape] = ConcreteSizeEntry( runtime_shape=shape, need_to_compile=shape in self.compile_sizes, use_cudagraph=shape in self.cudagraph_capture_sizes, - usage_type="piecewise(general)", # for logging only + usage_type=usage_type, # for debug logging only ) + self.cudagraph_wrapper = CUDAGraphWrapper(vllm_config) + def check_for_ending_compilation(self): if self.is_last_graph and not self.to_be_compiled_sizes: # no specific sizes to compile @@ -151,78 +261,17 @@ def __call__(self, *args) -> Any: if not entry.use_cudagraph or not skip_attention_cuda_graphs: return entry.runnable(*args) - if entry.cudagraph is None: - if entry.num_finished_warmup < self.compilation_config.cudagraph_num_of_warmups: # noqa - entry.num_finished_warmup += 1 - if self.is_first_graph: - logger.debug( - "Warming up %s/%s of %s usage for shape %s", - entry.num_finished_warmup, - self.compilation_config.cudagraph_num_of_warmups, - entry.usage_type, runtime_shape) - return entry.runnable(*args) - - if self.is_first_graph: - # Since we capture cudagraph for many different shapes and - # capturing is fast, we don't need to log it for every shape. - # We only log it in the debug mode. - logger.debug("Capturing a cudagraph of %s usage for shape %s", - entry.usage_type, runtime_shape) - - input_addresses = [ - x.data_ptr() for x in args if isinstance(x, torch.Tensor) - ] - entry.input_addresses = input_addresses - cudagraph = torch.cuda.CUDAGraph() - - with ExitStack() as stack: - if not self.is_first_graph: - # during every model forward, we will capture - # many pieces of cudagraphs (roughly one per layer). - # running gc again and again across layers will - # make the cudagraph capture very slow. - # therefore, we only run gc for the first graph, - # and disable gc for the rest of the graphs. - stack.enter_context(patch("gc.collect", lambda: None)) - stack.enter_context( - patch("torch.cuda.empty_cache", lambda: None)) - - # mind-exploding: carefully manage the reference and memory. - with torch.cuda.graph(cudagraph, pool=self.graph_pool): - # `output` is managed by pytorch's cudagraph pool - output = entry.runnable(*args) - if self.is_last_graph: - # by converting it to weak ref, - # the original `output` will immediately be released - # to save memory. It is only safe to do this for - # the last graph, because the output of the last graph - # will not be used by any other cuda graph. 
- output = weak_ref_tensors(output) - - # here we always use weak ref for the output - # to save memory - entry.output = weak_ref_tensors(output) - entry.cudagraph = cudagraph - - compilation_counter.num_cudagraph_captured += 1 - - # important: we need to return the output, rather than - # the weak ref of the output, so that pytorch can correctly - # manage the memory during cuda graph capture - return output - - if self.is_debugging_mode: - # check if the input addresses are the same - new_input_addresses = [ - x.data_ptr() for x in args if isinstance(x, torch.Tensor) - ] - assert new_input_addresses == entry.input_addresses, ( - "Input addresses for cudagraphs are different during replay." - f" Expected {entry.input_addresses}, got {new_input_addresses}" - ) - - entry.cudagraph.replay() - return entry.output + if entry.cudagraph_runnable is None: + cudagraph_runnable_config = { + "graph_pool": self.graph_pool, + "debug_capturing": self.is_first_graph, + "gc_disable": not self.is_first_graph, + "weak_ref_output": self.is_last_graph, + } + entry.cudagraph_runnable = \ + self.cudagraph_wrapper.create_cudagraph_runnable(entry, + cudagraph_runnable_config) + return entry.cudagraph_runnable(*args) class FullCudagraphWrapper: @@ -269,6 +318,8 @@ def __init__( usage_type="decode", ) + self.cudagraph_wrapper = CUDAGraphWrapper(vllm_config) + def __call__(self, *args) -> Any: if not self.first_run_finished: self.first_run_finished = True @@ -287,7 +338,7 @@ def __call__(self, *args) -> Any: # as a whole. concrete_size_entries = self.concrete_size_entries - if self.separate_attention_routine and forward_context.is_pure_decoding: + if self.separate_attention_routine and forward_context.is_pure_decode: concrete_size_entries = self.concrete_size_entries_decode if runtime_shape not in concrete_size_entries: @@ -302,60 +353,10 @@ def __call__(self, *args) -> Any: if not entry.use_cudagraph: return entry.runnable(*args) - if entry.cudagraph is None: - if entry.num_finished_warmup < self.compilation_config.cudagraph_num_of_warmups: # noqa - entry.num_finished_warmup += 1 - logger.debug("Warming up %s/%s of %s usage for shape %s", - entry.num_finished_warmup, - self.compilation_config.cudagraph_num_of_warmups, - entry.usage_type, runtime_shape) - return entry.runnable(*args) - - # Since we capture cudagraph for many different shapes and - # capturing is fast, we don't need to log it for every shape. - # We only log it in the debug mode. - - logger.debug("Capturing a cudagraph of %s usage for shape %s", - entry.usage_type, runtime_shape) - - input_addresses = [ - x.data_ptr() for x in args if isinstance(x, torch.Tensor) - ] - entry.input_addresses = input_addresses - cudagraph = torch.cuda.CUDAGraph() - - with ExitStack(), \ - torch.cuda.graph(cudagraph, pool=self.graph_pool): - # mind-exploding: carefully manage the reference and memory. - - # `output` is managed by pytorch's cudagraph pool - output = entry.runnable(*args) - # by converting it to weak ref, - # the original `output` will immediately be released - # to save memory. 
- output = weak_ref_tensors(output) - - # here we always use weak ref for the output - # to save memory - entry.output = weak_ref_tensors(output) - entry.cudagraph = cudagraph - - compilation_counter.num_cudagraph_captured += 1 - - # important: we need to return the output, rather than - # the weak ref of the output, so that pytorch can correctly - # manage the memory during cuda graph capture - return output - - if self.is_debugging_mode: - # check if the input addresses are the same - new_input_addresses = [ - x.data_ptr() for x in args if isinstance(x, torch.Tensor) - ] - assert new_input_addresses == entry.input_addresses, ( - "Input addresses for cudagraphs are different during replay." - f" Expected {entry.input_addresses}, got {new_input_addresses}" - ) + if entry.cudagraph_runnable is None: + cudagraph_runnable_config = {"graph_pool": self.graph_pool} + entry.cudagraph_runnable = \ + self.cudagraph_wrapper.create_cudagraph_runnable(entry, + cudagraph_runnable_config) - entry.cudagraph.replay() - return entry.output + return entry.cudagraph_runnable(*args) diff --git a/vllm/forward_context.py b/vllm/forward_context.py index 6440af712a8..4db26b6d49f 100644 --- a/vllm/forward_context.py +++ b/vllm/forward_context.py @@ -98,7 +98,7 @@ class ForwardContext: # cudagraphs that skip the attention part. By default true, we use piecewise # cudagraphs. skip_attention_cuda_graphs: bool = True - is_pure_decoding: bool = False + is_pure_decode: bool = False _forward_context: Optional[ForwardContext] = None @@ -120,7 +120,7 @@ def set_forward_context( num_tokens: Optional[int] = None, num_tokens_across_dp: Optional[torch.Tensor] = None, skip_attention_cuda_graphs: bool = True, - is_pure_decoding: bool = False, + is_pure_decode: bool = False, ): """A context manager that stores the current forward context, can be attention metadata, etc. @@ -146,7 +146,7 @@ def set_forward_context( attn_metadata=attn_metadata, dp_metadata=dp_metadata, skip_attention_cuda_graphs=skip_attention_cuda_graphs, - is_pure_decoding=is_pure_decoding, + is_pure_decode=is_pure_decode, ) try: diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index e20239eb957..d0436f3ee9e 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -141,7 +141,7 @@ def _get_sliding_window_configs( class FlashAttentionMetadataBuilder( AttentionMetadataBuilder[FlashAttentionMetadata]): - full_cudagraph_supported: ClassVar[bool] = get_flash_attn_version() >= 2 + full_cudagraph_supported: ClassVar[bool] = True # FlashAttn support a unified varlen fwd kernel for prefill-decode phase, so # it's ok to either separate attention routine or not for both FA2 or 3. 
force_separate_routine: ClassVar[Optional[bool]] = None diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py index eb0e40a699d..4c5a1a755c1 100644 --- a/vllm/v1/attention/backends/triton_attn.py +++ b/vllm/v1/attention/backends/triton_attn.py @@ -74,7 +74,6 @@ class LocalAttentionMetadata: class TritonAttentionMetadataBuilder( AttentionMetadataBuilder[TritonAttentionMetadata]): full_cudagraph_supported: ClassVar[bool] = True - force_separate_routine: ClassVar[Optional[bool]] = None def __init__(self, runner: "GPUModelRunner", kv_cache_spec: AttentionSpec, block_table: BlockTable): diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index eb9662d2795..0fc48b39727 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1362,13 +1362,12 @@ def execute_model( # If attention doesn't support CUDA Graphs for this batch, we skip them, # and turn back to the piecewise CUDA graphs. Or if full_cuda_graph is # False, we always turn to the piecewise CUDA graphs. - skip_attention_cuda_graphs = not attention_cuda_graphs \ - if self.full_cuda_graph else True + skip_attention_cuda_graphs = not self.full_cuda_graph or not attention_cuda_graphs # noqa: E501 # Note: When skip_attention_cuda_graphs is always False and # compilition_config.separate_attention_routine is True, as in FA2, # this flag helps to determine the correct routine for the full # cudagraph. - is_pure_decoding = num_scheduled_tokens == self.input_batch.num_reqs + is_pure_decode = num_scheduled_tokens == self.input_batch.num_reqs # Run the model. # Use persistent buffers for CUDA graphs. @@ -1378,7 +1377,7 @@ def execute_model( num_tokens=num_input_tokens, num_tokens_across_dp=num_tokens_across_dp, skip_attention_cuda_graphs=skip_attention_cuda_graphs, - is_pure_decoding=is_pure_decoding, + is_pure_decode=is_pure_decode, ): self.maybe_setup_kv_connector(scheduler_output) @@ -1934,7 +1933,7 @@ def _dummy_run( self, num_tokens: int, capture_attn_cudagraph: Union[bool, Literal["auto"]] = False, - is_pure_decoding: bool = False, + is_pure_decode: bool = False, skip_eplb: bool = False, is_profile: bool = False, ) -> tuple[torch.Tensor, torch.Tensor]: @@ -1957,9 +1956,9 @@ def _dummy_run( num_scheduled_tokens = np.array(num_scheduled_tokens_list, dtype=np.int32) - # [Bugfix] This lets FA2 to correctly activate the optimized routine - # for pure decoding, i.e., Flashdecoding + an optimization for GQA/MQA. - max_query_len = 1 if is_pure_decoding else num_tokens + # This lets FA2 to correctly activate the optimized routine for + # pure decoding, i.e., Flashdecoding + an optimization for GQA/MQA. + max_query_len = 1 if is_pure_decode else num_tokens attn_metadata: Optional[dict[str, Any]] = None skip_attention_cuda_graphs = True @@ -2038,7 +2037,7 @@ def _dummy_run( num_tokens=num_tokens, num_tokens_across_dp=num_tokens_across_dp, skip_attention_cuda_graphs=skip_attention_cuda_graphs, - is_pure_decoding=is_pure_decoding): + is_pure_decode=is_pure_decode): outputs = model( input_ids=input_ids, positions=positions, @@ -2290,7 +2289,6 @@ def capture_model(self) -> None: with graph_capture(device=self.device): full_cg = self.full_cuda_graph - # If full_cuda_graph is true, automatically determine whether or # not to capture the attention for the mix prefill-decode (general) # phase, based on the attention backends. 
@@ -2304,9 +2302,9 @@ def capture_model(self) -> None: and len(self.cudagraph_batch_sizes) > 0 \ and self.cudagraph_batch_sizes[0] == 1: start_idx = 1 - + # We skip EPLB here since we don't want to record dummy metrics - + # Capture the mix prefill-decode (general usage) cudagraphs for num_tokens in tqdm( reversed(self.cudagraph_batch_sizes[start_idx:]), @@ -2317,12 +2315,12 @@ def capture_model(self) -> None: self._dummy_run( num_tokens, capture_attn_cudagraph=capture_attn_cudagraph_general, - is_pure_decoding=False, + is_pure_decode=False, skip_eplb=True) self._dummy_run( num_tokens, capture_attn_cudagraph=capture_attn_cudagraph_general, - is_pure_decoding=False, + is_pure_decode=False, skip_eplb=True) if self.vllm_config.compilation_config.separate_attention_routine: @@ -2340,14 +2338,13 @@ def capture_model(self) -> None: self.compilation_config.cudagraph_num_of_warmups): self._dummy_run(num_tokens, capture_attn_cudagraph=full_cg, - is_pure_decoding=True, + is_pure_decode=True, skip_eplb=True) self._dummy_run(num_tokens, capture_attn_cudagraph=full_cg, - is_pure_decoding=True, + is_pure_decode=True, skip_eplb=True) - end_time = time.perf_counter() end_free_gpu_memory = torch.cuda.mem_get_info()[0] elapsed_time = end_time - start_time From da110afef29151caea4c7c8b2ac71176d718d8f1 Mon Sep 17 00:00:00 2001 From: fhl2000 <63384265+fhl2000@users.noreply.github.com> Date: Sat, 28 Jun 2025 07:39:02 +0000 Subject: [PATCH 09/33] Refactors Signed-off-by: fhl2000 <63384265+fhl2000@users.noreply.github.com> --- vllm/v1/attention/backends/flashinfer.py | 9 +++------ vllm/v1/worker/gpu_model_runner.py | 2 +- 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index 3a21dc9e7d4..bffe774b38a 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -544,8 +544,6 @@ def build(self, common_prefix_len: int, num_actual_pages = paged_kv_indices.size(0) self.paged_kv_indices[:num_actual_pages].copy_(paged_kv_indices, non_blocking=True) - # Fill the remaining paged_kv_last_page_len with 1. This is because - # flashinfer treats 0 as a full page instead of empty. self.paged_kv_indices[num_actual_pages:].fill_(-1) paged_kv_indptr = torch.cat([ @@ -564,6 +562,8 @@ def build(self, common_prefix_len: int, page_size, paged_kv_last_page_len) self.paged_kv_last_page_len[:num_reqs].copy_(paged_kv_last_page_len, non_blocking=True) + # Fill the remaining paged_kv_last_page_len with 1. This is because + # flashinfer treats 0 as a full page instead of empty. self.paged_kv_last_page_len[num_reqs:].fill_(1) attn_metadata = FlashInferMetadata( @@ -601,10 +601,7 @@ def build_for_cudagraph_capture( Currently, only decode is supported for full cudagraphs with FlashInfer. """ m = common_attn_metadata - m.query_start_loc.copy_(torch.arange(m.num_actual_tokens + 1, - dtype=torch.int32, - device=self.runner.device), - non_blocking=True) + assert m.num_reqs == m.num_actual_tokens, \ "FlashInfer only supports decode-only full CUDAGraph capture. " \ "Make sure all cudagraph capture sizes <= max_num_seq." 
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 0fc48b39727..86df975ee65 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2309,7 +2309,7 @@ def capture_model(self) -> None: for num_tokens in tqdm( reversed(self.cudagraph_batch_sizes[start_idx:]), desc="Capturing CUDA graphs (mix prefill-decode)", - total=len(self.cudagraph_batch_sizes)): + total=len(self.cudagraph_batch_sizes) - start_idx): for _ in range( self.compilation_config.cudagraph_num_of_warmups): self._dummy_run( From deaf0fe23c01b450f747623c288ac687eebb815f Mon Sep 17 00:00:00 2001 From: fhl2000 <63384265+fhl2000@users.noreply.github.com> Date: Sat, 28 Jun 2025 07:48:02 +0000 Subject: [PATCH 10/33] Delect one commit Signed-off-by: fhl2000 <63384265+fhl2000@users.noreply.github.com> --- CMakeLists.txt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 15c67073ba4..b1adeac586f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -308,7 +308,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") # Keep building Marlin for 9.0 as there are some group sizes and shapes that # are not supported by Machete yet. # 9.0 for latest bf16 atomicAdd PTX - cuda_archs_loose_intersection(MARLIN_ARCHS "8.0;8.7;8.9;9.0+PTX" "${CUDA_ARCHS}") + cuda_archs_loose_intersection(MARLIN_ARCHS "8.0;8.7;9.0+PTX" "${CUDA_ARCHS}") if (MARLIN_ARCHS) # @@ -702,7 +702,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") list(APPEND VLLM_MOE_EXT_SRC "${VLLM_MOE_WNA16_SRC}") # 9.0 for latest bf16 atomicAdd PTX - cuda_archs_loose_intersection(MARLIN_MOE_ARCHS "8.0;8.7;8.9;9.0+PTX" "${CUDA_ARCHS}") + cuda_archs_loose_intersection(MARLIN_MOE_ARCHS "8.0;8.7;9.0+PTX" "${CUDA_ARCHS}") if (MARLIN_MOE_ARCHS) # From 02ca154213e7e4e58783d346d509e2afe0806727 Mon Sep 17 00:00:00 2001 From: fhl2000 <63384265+fhl2000@users.noreply.github.com> Date: Sat, 28 Jun 2025 16:58:13 +0000 Subject: [PATCH 11/33] Add support for force_no_split_graph Signed-off-by: fhl2000 <63384265+fhl2000@users.noreply.github.com> --- vllm/config.py | 40 ++++++++++++++++------- vllm/v1/attention/backends/flash_attn.py | 2 ++ vllm/v1/attention/backends/triton_attn.py | 2 ++ vllm/v1/attention/backends/utils.py | 6 ++++ vllm/v1/worker/gpu_model_runner.py | 8 +++++ 5 files changed, 47 insertions(+), 11 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 93c96facec0..edf2d4d440a 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -4066,6 +4066,14 @@ class CompilationConfig: prefill-decode and pure decode cases. This flag enables us to potentially capture the cudagraph separately for each branch. """ + force_no_split_graph: bool = False + """ + Enforce the fx graph to be a flattened full graph instead of a piecewise fx + graph with submodules (splited by attention ops, as the default behavior). + + Maintaining the full graph may offer benefits in some cases, e.g., enabling + attention-related custom inductor passes, such as attention+quant fusion. + """ pass_config: PassConfig = field(default_factory=PassConfig) """Custom inductor passes, see PassConfig for more details""" @@ -4264,21 +4272,31 @@ def init_with_cudagraph_sizes(self, def set_splitting_ops_for_v1(self): # NOTE: this function needs to be called - # NOTE: When full_cuda_graph is True, instead of setting an empty - # list and capture the full cudagraph inside the flattened fx graph, - # we keep the piecewise fx graph structure but capture the full - # cudagraph outside the fx graph. 
This reduces some cpu overhead when
-        # the runtime batch_size is not cudagraph captured. This is only
-        # supported for separate_attention_routine.
         if self.separate_attention_routine:
             assert self.full_cuda_graph, (
                 "separate_attention_routine requires "
                 "full_cuda_graph to be True")
-        if not self.splitting_ops:
-            self.splitting_ops = [
-                "vllm.unified_attention",
-                "vllm.unified_attention_with_output",
-            ]
+
+        if self.force_no_split_graph:
+            assert self.full_cuda_graph, (
+                "force_no_split_graph requires full_cuda_graph to be True")
+            assert not self.splitting_ops, (
+                "force_no_split_graph cannot be used when "
+                "splitting_ops is not empty. Please set splitting_ops "
+                "to [] if you want to use force_no_split_graph.")
+            self.splitting_ops = []
+        else:
+            # NOTE: When full_cuda_graph is True, instead of setting an empty
+            # list and capturing the full cudagraph inside the flattened fx
+            # graph, we keep the piecewise fx graph structure but capture the
+            # full cudagraph outside the fx graph. This reduces some cpu
+            # overhead when the runtime batch_size is not cudagraph captured.
+            # See PR #20059.
+            if not self.splitting_ops:
+                self.splitting_ops = [
+                    "vllm.unified_attention",
+                    "vllm.unified_attention_with_output",
+                ]
 
 
 @config
diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py
index d0436f3ee9e..8312961c693 100755
--- a/vllm/v1/attention/backends/flash_attn.py
+++ b/vllm/v1/attention/backends/flash_attn.py
@@ -146,6 +146,8 @@ class FlashAttentionMetadataBuilder(
     # it's ok to either separate attention routine or not for both FA2 or 3.
     force_separate_routine: ClassVar[Optional[bool]] = None
 
+    support_full_cudagraph_only: ClassVar[bool] = True
+
     def __init__(self, runner: "GPUModelRunner", kv_cache_spec: AttentionSpec,
                  block_table: BlockTable):
         model_config = runner.model_config
diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py
index 4c5a1a755c1..ef856257c38 100644
--- a/vllm/v1/attention/backends/triton_attn.py
+++ b/vllm/v1/attention/backends/triton_attn.py
@@ -75,6 +75,8 @@ class TritonAttentionMetadataBuilder(
         AttentionMetadataBuilder[TritonAttentionMetadata]):
     full_cudagraph_supported: ClassVar[bool] = True
 
+    support_full_cudagraph_only: ClassVar[bool] = True
+
     def __init__(self, runner: "GPUModelRunner", kv_cache_spec: AttentionSpec,
                  block_table: BlockTable):
         self.runner = runner
diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py
index 9836a04fb58..5750070d018 100644
--- a/vllm/v1/attention/backends/utils.py
+++ b/vllm/v1/attention/backends/utils.py
@@ -54,6 +54,12 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]):
     # enforces the separate routine to be True, False or None (free).
     force_separate_routine: ClassVar[Optional[bool]] = None
 
+    # Whether the attention backend supports full cudagraph only, i.e., it
+    # supports full cudagraph for both the prefill-decode phase and the pure
+    # decode phase. Used to check that compilation_config.force_no_split_graph
+    # = True is valid when full cudagraph is enabled.
+ support_full_cudagraph_only: ClassVar[bool] = False + @abstractmethod def build(self, common_prefix_len: int, common_attn_metadata: CommonAttentionMetadata) -> M: diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 86df975ee65..841fc2d662c 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2404,6 +2404,14 @@ def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None: f"CompilationConfig.full_cuda_graph or use a different" f" attention backend.") + if self.compilation_config.force_no_split_graph: + assert attn_metadata_builder_i.support_full_cudagraph_only, ( # noqa: E501 + f"Full CUDAGraph not supported for " + f"{attn_backend_i.__name__} with " + f"CompilationConfig.force_no_split_graph=True. " + f"Turn off CompilationConfig.force_no_split_graph" + f"or use a different attention backend.") + # check if the attention backends enforce to have separate # routines for mix prefill-decode and pure decode phase if attn_metadata_builder_i.force_separate_routine is not None \ From 5108befa10d361e416ab34ef077228434fc5516c Mon Sep 17 00:00:00 2001 From: fhl <2410591650@qq.com> Date: Sat, 5 Jul 2025 16:00:22 +0800 Subject: [PATCH 12/33] Huge refactors to separete cudagraph logic from vllm compilation Signed-off-by: fhl <2410591650@qq.com> --- vllm/compilation/backends.py | 44 +-- vllm/compilation/base_piecewise_backend.py | 115 ------- vllm/compilation/base_static_graph.py | 57 ++++ vllm/compilation/cuda_graph.py | 180 ++++++++++ vllm/compilation/cuda_piecewise_backend.py | 362 --------------------- vllm/compilation/piecewise_backend.py | 179 ++++++++++ vllm/config.py | 109 ++++--- vllm/forward_context.py | 18 +- vllm/platforms/cuda.py | 8 +- vllm/platforms/interface.py | 13 +- vllm/platforms/rocm.py | 4 +- vllm/v1/attention/backends/flash_attn.py | 10 +- vllm/v1/attention/backends/flashinfer.py | 12 +- vllm/v1/attention/backends/mla/flashmla.py | 5 +- vllm/v1/attention/backends/triton_attn.py | 6 +- vllm/v1/attention/backends/utils.py | 34 +- vllm/v1/worker/gpu_model_runner.py | 264 +++++++++++---- 17 files changed, 753 insertions(+), 667 deletions(-) delete mode 100644 vllm/compilation/base_piecewise_backend.py create mode 100644 vllm/compilation/base_static_graph.py create mode 100644 vllm/compilation/cuda_graph.py delete mode 100644 vllm/compilation/cuda_piecewise_backend.py create mode 100644 vllm/compilation/piecewise_backend.py diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 542869687ab..8235a7e9d34 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -18,7 +18,7 @@ from vllm.config import CompilationConfig, VllmConfig from vllm.logger import init_logger from vllm.platforms import current_platform -from vllm.utils import is_torch_equal_or_newer, resolve_obj_by_qualname +from vllm.utils import is_torch_equal_or_newer from .compiler_interface import (CompilerInterface, EagerAdaptor, InductorAdaptor, InductorStandaloneAdaptor) @@ -258,6 +258,12 @@ def split_graph(graph: fx.GraphModule, # we share the global graph pool among all the backends global_graph_pool = None +def get_global_graph_pool(): + global global_graph_pool + if global_graph_pool is None: + global_graph_pool = current_platform.graph_pool_handle() + return global_graph_pool + compilation_start_time = 0.0 @@ -317,10 +323,9 @@ def call_module(self, target: torch.fx.node.Target, graph_index=index, num_graphs=len(self.compile_submod_names), runtime_shape=None) - - piecewise_backend = 
resolve_obj_by_qualname( - current_platform.get_piecewise_backend_cls()) - self.module.__dict__[target] = piecewise_backend( + # Lazy import here to avoid circular import + from .piecewise_backend import PiecewiseBackend + self.module.__dict__[target] = PiecewiseBackend( submod, self.vllm_config, self.graph_pool, index, len(self.compile_submod_names), sym_shape_indices, compiled_graph_for_general_shape, self.vllm_backend) @@ -391,9 +396,8 @@ def __init__( # them, e.g. backbone (default), eagle_head, etc. self.prefix = prefix or model_tag - global global_graph_pool - if global_graph_pool is None: - global_graph_pool = current_platform.graph_pool_handle() + + global_graph_pool = get_global_graph_pool() # TODO: in the future, if we want to use multiple # streams, it might not be safe to share a global pool. @@ -563,6 +567,10 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable: self._called = True + if not self.compilation_config.use_cudagraph or \ + not self.compilation_config.cudagraph_copy_inputs: + return self.split_gm + # if we need to copy input buffers for cudagraph from torch._guards import detect_fake_mode fake_mode = detect_fake_mode() @@ -581,18 +589,14 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable: any(is_symbolic(d) for d in x.size()) ] - if self.compilation_config.full_cuda_graph: - assert self.compilation_config.use_cudagraph, \ - "full_cuda_graph mode requires use_cudagraph to be True" - fullgraph_wrapper = resolve_obj_by_qualname( - current_platform.get_fullgraph_wrapper_cls()) - self.split_gm = fullgraph_wrapper(self.split_gm, self.vllm_config, - self.graph_pool, - self.sym_tensor_indices) - - if not self.compilation_config.use_cudagraph or \ - not self.compilation_config.cudagraph_copy_inputs: - return self.split_gm + # if self.compilation_config.full_cuda_graph: + # assert self.compilation_config.use_cudagraph, \ + # "full_cuda_graph mode requires use_cudagraph to be True" + # fullgraph_wrapper = resolve_obj_by_qualname( + # current_platform.get_fullgraph_wrapper_cls()) + # self.split_gm = fullgraph_wrapper(self.split_gm, self.vllm_config, + # self.graph_pool, + # self.sym_tensor_indices) # compiler managed cudagraph input buffers # we assume the first run with symbolic shapes diff --git a/vllm/compilation/base_piecewise_backend.py b/vllm/compilation/base_piecewise_backend.py deleted file mode 100644 index 854c9146543..00000000000 --- a/vllm/compilation/base_piecewise_backend.py +++ /dev/null @@ -1,115 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from typing import Any, Callable, Protocol - -import torch.fx as fx - -from vllm.compilation.backends import VllmBackend -from vllm.config import VllmConfig - - -class AbstractPiecewiseBackend(Protocol): - """ - PiecewiseBackend interface that allows platforms to extend - piecewise static graph. - """ - - def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig, - graph_pool: Any, piecewise_compile_index: int, - total_piecewise_compiles: int, sym_shape_indices: list[int], - compiled_graph_for_general_shape: Callable, - vllm_backend: VllmBackend, **kwargs): - """ - Initializes the PiecewiseBackend class with compilation and - execution-related configurations. - - This class handles piecewise compilation, graph capturing, - and dispatching for specific input shapes. - - Args: - graph (fx.GraphModule): The graph represented in fx. - vllm_config (VllmConfig): Global configuration for vLLM. 
- graph_pool (Any): - Graph memory pool handle, e.g., - `torch.cuda.graph_pool_handle()`. - piecewise_compile_index (int): - Index of the current piecewise subgraph. - total_piecewise_compiles (int): - Total number of piecewise-compiled graphs. - sym_shape_indices (list[int]): - Indices of symbolic shape. - compiled_graph_for_general_shape (Callable): - Callable that executes the graph compiled for general shapes. - vllm_backend (VllmBackend): - Backend compiler that manages compilation and graph runtime - for vLLM. - - Keyword Args: - kwargs: Additional keyword arguments reserved for future - extensions or custom platforms. - """ - raise NotImplementedError - - def __call__(self, *args) -> Any: - """Executes the compiled graph for given input args. - - If this is the first invocation, executes the general compiled graph - and initiates the compilation process tracking. For subsequent calls, - dynamically dispatches execution to either a compiled graph or a static - graph based on the input shape. - - Args: - *args: Variable length input arguments to be passed into the - graph. The symbolic shape is expected to be in position - `sym_shape_indices[0]`. - - Returns: - Any: Output of the executed graph. This can be from the general - compiled graph, a specialized compiled version for the given shape, - or a replayed static graph. - """ - raise NotImplementedError - - -class AbstractFullgraphWrapper(Protocol): - """ - FullgraphWrapper interface that allows platforms to wrap the piecewise graph - to be viewed or captured as a full graph. - """ - - def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig, - graph_pool: Any, sym_shape_indices: list[int], **kwargs): - """ - Initializes the FullgraphWrapper class with compilation and - execution-related configurations. - - Args: - graph (fx.GraphModule): The graph represented in fx. - vllm_config (VllmConfig): Global configuration for vLLM. - graph_pool (Any): - Graph memory pool handle, e.g., - `torch.cuda.graph_pool_handle()`. - sym_shape_indices (list[int]): - Indices of symbolic shape. - - Keyword Args: - kwargs: Additional keyword arguments reserved for future - extensions or custom platforms. - - """ - raise NotImplementedError - - def __call__(self, *args) -> Any: - """ - Executes the wrapped graph for given input args. - - Args: - *args: Variable length input arguments to be passed into the - graph. The symbolic shape is expected to be in position - `sym_shape_indices[0]`. - - Returns: - Any: Output of the executed wrapped graph. - """ - raise NotImplementedError diff --git a/vllm/compilation/base_static_graph.py b/vllm/compilation/base_static_graph.py new file mode 100644 index 00000000000..7c95b6eedf4 --- /dev/null +++ b/vllm/compilation/base_static_graph.py @@ -0,0 +1,57 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Any, Callable, Protocol + +from vllm.config import VllmConfig + + +class AbstractStaticGraphWrapper(Protocol): + """ + StaticGraphWrapper interface that allows platforms to wrap a callable + to be captured as a static graph. + """ + + def __init__(self, runnable: Callable, vllm_config: VllmConfig, + graph_pool: Any, runtime_style: Any, **kwargs): + """ + Initializes the StaticGraphWrapper class with graph capturing and + execution-related configurations. + + Args: + runnable (Callable): The callable to be wrapped and captured. + vllm_config (VllmConfig): Global configuration for vLLM. 
+ graph_pool (Any): + Graph memory pool handle, e.g., + `torch.cuda.graph_pool_handle()`. + runtime_style (Any): The style of the static + graph runtime. + Keyword Args: + kwargs: Additional keyword arguments for platform-specific + configurations. + """ + raise NotImplementedError + + def maybe_replace_runnable(self, shape: int, runnable: Any): + """ + Replaces the runnable with a new one for a specific compiled shape. + """ + raise NotImplementedError + + def __call__(self, *args, **kwargs) -> Any: + """ + Executes the wrapped callable. + + This may involve replaying a captured static graph if the conditions + are met, or running the original callable eagerly and potentially + capturing it. + + Args: + *args: Variable length input arguments to be passed into the + callable. + **kwargs: Keyword arguments to be passed into the callable. + + Returns: + Any: Output of the executed callable. + """ + raise NotImplementedError diff --git a/vllm/compilation/cuda_graph.py b/vllm/compilation/cuda_graph.py new file mode 100644 index 00000000000..45d39cc6653 --- /dev/null +++ b/vllm/compilation/cuda_graph.py @@ -0,0 +1,180 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import dataclasses +from contextlib import ExitStack +from typing import Any, Callable, Optional +from unittest.mock import patch + +import torch + +import vllm.envs as envs +from vllm.compilation.counter import compilation_counter +from vllm.config import VllmConfig, CUDAGraphRuntimeStyle +from vllm.forward_context import get_forward_context +from vllm.logger import init_logger +from vllm.utils import weak_ref_tensors + +logger = init_logger(__name__) + + +@dataclasses.dataclass +class CUDAGraphEntry: + runtime_shape: int + num_finished_warmup: int = 0 + runnable: Callable = None # type: ignore + cudagraph: Optional[torch.cuda.CUDAGraph] = None + output: Optional[Any] = None + + # for cudagraph debugging, track the input addresses + # during capture, and check if they are the same during replay + input_addresses: Optional[list[int]] = None + + usage_type: Optional[str] = None # For debug logging only + + +class CUDAGraphWrapper: + """ + This class simply wrap a runnable for cudagraph functionality, + taking responsibility of capturing cudagraph and running the replay. 
+ """ + + def __init__(self, runnable: Any, vllm_config: VllmConfig, graph_pool: Any, + runtime_style: CUDAGraphRuntimeStyle, + cudagraph_specific_config: dict[str, Any]={}): + self.runnable = runnable + self.vllm_config = vllm_config + self.graph_pool = graph_pool + self.runtime_style = runtime_style + self.compilation_config = vllm_config.compilation_config + + self.first_run_finished = False + self.is_debugging_mode = envs.VLLM_LOGGING_LEVEL == "DEBUG" + + assert self.runtime_style >= CUDAGraphRuntimeStyle.PIECEWISE + assert graph_pool is not None + self.debug_capturing = cudagraph_specific_config.get( + "debug_capturing", True) + self.gc_disable = cudagraph_specific_config.get( + "gc_disable", False) + self.weak_ref_output = cudagraph_specific_config.get( + "weak_ref_output", True) + usage_type = cudagraph_specific_config.get("usage_type", None) + self.cudagraph_capture_sizes: set[int] = set( + self.compilation_config.cudagraph_capture_sizes + ) + # the entries for different shapes that we need to capture cudagraph + self.concrete_cudagraph_entries: dict[int, CUDAGraphEntry] = {} + + for shape in self.cudagraph_capture_sizes: + + self.concrete_cudagraph_entries[shape] = CUDAGraphEntry( + runtime_shape=shape, + runnable=self.runnable, + usage_type=usage_type, # for debug logging only + ) + + def maybe_replace_runnable(self, shape: int, runnable: Callable): + # this is a hack to replace a general shape runnable with a compiled + # runnable of a specific shape. + if shape not in self.concrete_cudagraph_entries: + return + entry = self.concrete_cudagraph_entries[shape] + assert entry.cudagraph is None, "Cudagraph is already captured" + entry.runnable = runnable + + def __call__(self, *args, **kwargs): + forward_context = get_forward_context() + runtime_shape = forward_context.num_tokens + cudagraph_runtime_style = forward_context.cudagraph_runtime_style + + if cudagraph_runtime_style == CUDAGraphRuntimeStyle.NONE or\ + runtime_shape is None: + # TODO: make sure here is on profile running or eager running + return self.runnable(*args, **kwargs) + if cudagraph_runtime_style != self.runtime_style: + # CUDAGraph runtime style don't match the current + # configuration, so directly call runnable eagerly + # as it's always safe. + return self.runnable(*args, **kwargs) + + if runtime_shape not in self.concrete_cudagraph_entries: + # we don't need to do anything for this shape. + return self.runnable(*args, **kwargs) + + entry = self.concrete_cudagraph_entries[runtime_shape] + + + if entry.cudagraph is None: + if entry.num_finished_warmup < self.compilation_config.cudagraph_num_of_warmups: # noqa + entry.num_finished_warmup += 1 + if self.debug_capturing: + logger.debug( + "Warming up %s/%s of %s usage for shape %s", + entry.num_finished_warmup, + self.compilation_config.cudagraph_num_of_warmups, + entry.usage_type, entry.runtime_shape) + return entry.runnable(*args, **kwargs) + + if self.debug_capturing: + # Since we capture cudagraph for many different shapes and + # capturing is fast, we don't need to log it for every + # shape. We only log it in the debug mode. 
+ logger.debug( + "Capturing a cudagraph of %s usage for shape %s", + entry.usage_type, entry.runtime_shape) + + input_addresses = [ + x.data_ptr() for x in args if isinstance(x, torch.Tensor) + ] + entry.input_addresses = input_addresses + cudagraph = torch.cuda.CUDAGraph() + + with ExitStack() as stack: + if self.gc_disable: + # during every model forward for piecewise cudagraph + # mode, we will capture many pieces of cudagraphs + # (roughly one per layer). running gc again and again + # across layers will make the cudagraph capture very slow. + # therefore, we only run gc for the first graph, + # and disable gc for the rest of the graphs. + stack.enter_context(patch("gc.collect", lambda: None)) + stack.enter_context( + patch("torch.cuda.empty_cache", lambda: None)) + + # mind-exploding: carefully manage the reference and memory. + with torch.cuda.graph(cudagraph, pool=self.graph_pool): + # `output` is managed by pytorch's cudagraph pool + output = entry.runnable(*args, **kwargs) + if self.weak_ref_output: + # by converting it to weak ref, + # the original `output` will immediately be released + # to save memory. It is only safe to do this for + # the last graph, because the output of the last + # graph will not be used by any other cuda graph. + output = weak_ref_tensors(output) + + # here we always use weak ref for the output + # to save memory + entry.output = weak_ref_tensors(output) + entry.cudagraph = cudagraph + + compilation_counter.num_cudagraph_captured += 1 + + # important: we need to return the output, rather than + # the weak ref of the output, so that pytorch can correctly + # manage the memory during cuda graph capture + return output + + if self.is_debugging_mode: + # check if the input addresses are the same + new_input_addresses = [ + x.data_ptr() for x in args if isinstance(x, torch.Tensor) + ] + assert new_input_addresses == entry.input_addresses, ( + "Input addresses for cudagraphs are different during " + f"replay. 
Expected {entry.input_addresses}, got " + f"{new_input_addresses}") + + entry.cudagraph.replay() + return entry.output \ No newline at end of file diff --git a/vllm/compilation/cuda_piecewise_backend.py b/vllm/compilation/cuda_piecewise_backend.py deleted file mode 100644 index 4feeeea9aa9..00000000000 --- a/vllm/compilation/cuda_piecewise_backend.py +++ /dev/null @@ -1,362 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import dataclasses -from contextlib import ExitStack -from typing import Any, Callable, Optional -from unittest.mock import patch - -import torch -import torch.fx as fx - -import vllm.envs as envs -from vllm.compilation.backends import VllmBackend -from vllm.compilation.counter import compilation_counter -from vllm.compilation.monitor import end_monitoring_torch_compile -from vllm.config import VllmConfig -from vllm.forward_context import get_forward_context -from vllm.logger import init_logger -from vllm.utils import weak_ref_tensors - -logger = init_logger(__name__) - - -@dataclasses.dataclass -class ConcreteSizeEntry: - runtime_shape: int - need_to_compile: bool # the size is in compile_sizes - use_cudagraph: bool # the size is in cudagraph_capture_sizes - - compiled: bool = False - runnable: Callable = None # type: ignore - num_finished_warmup: int = 0 - cudagraph: Optional[torch.cuda.CUDAGraph] = None - output: Optional[Any] = None - - cudagraph_runnable: Optional[Callable] = None - - # for cudagraph debugging, track the input addresses - # during capture, and check if they are the same during replay - input_addresses: Optional[list[int]] = None - - usage_type: Optional[str] = None - - -class CUDAGraphWrapper: - """ - A wrapper class for cudagraphs functionality. - - This class creates a cudagraph runnable for a given `ConcreteSizeEntry`, - taking responsibility of capturing cudagraph and running the replay. - """ - - def __init__(self, vllm_config: VllmConfig): - self.vllm_config = vllm_config - self.compilation_config = vllm_config.compilation_config - self.is_debugging_mode = envs.VLLM_LOGGING_LEVEL == "DEBUG" - - def create_cudagraph_runnable( - self, entry: ConcreteSizeEntry, - cudagraph_runnable_config: dict[str, Any]) -> Any: - graph_pool = cudagraph_runnable_config["graph_pool"] - assert graph_pool is not None - debug_capturing = cudagraph_runnable_config.get( - "debug_capturing", True) - gc_disable = cudagraph_runnable_config.get("gc_disable", False) - weak_ref_output = cudagraph_runnable_config.get( - "weak_ref_output", True) - - def cudagraph_runnable(*args): - if entry.cudagraph is None: - if entry.num_finished_warmup < self.compilation_config.cudagraph_num_of_warmups: # noqa - entry.num_finished_warmup += 1 - if debug_capturing: - logger.debug( - "Warming up %s/%s of %s usage for shape %s", - entry.num_finished_warmup, - self.compilation_config.cudagraph_num_of_warmups, - entry.usage_type, entry.runtime_shape) - return entry.runnable(*args) - - if debug_capturing: - # Since we capture cudagraph for many different shapes and - # capturing is fast, we don't need to log it for every - # shape. We only log it in the debug mode. 
- logger.debug( - "Capturing a cudagraph of %s usage for shape %s", - entry.usage_type, entry.runtime_shape) - - input_addresses = [ - x.data_ptr() for x in args if isinstance(x, torch.Tensor) - ] - entry.input_addresses = input_addresses - cudagraph = torch.cuda.CUDAGraph() - - with ExitStack() as stack: - if gc_disable: - # during every model forward, we will capture - # many pieces of cudagraphs (roughly one per layer). - # running gc again and again across layers will - # make the cudagraph capture very slow. - # therefore, we only run gc for the first graph, - # and disable gc for the rest of the graphs. - stack.enter_context(patch("gc.collect", lambda: None)) - stack.enter_context( - patch("torch.cuda.empty_cache", lambda: None)) - - # mind-exploding: carefully manage the reference and memory. - with torch.cuda.graph(cudagraph, pool=graph_pool): - # `output` is managed by pytorch's cudagraph pool - output = entry.runnable(*args) - if weak_ref_output: - # by converting it to weak ref, - # the original `output` will immediately be released - # to save memory. It is only safe to do this for - # the last graph, because the output of the last - # graph will not be used by any other cuda graph. - output = weak_ref_tensors(output) - - # here we always use weak ref for the output - # to save memory - entry.output = weak_ref_tensors(output) - entry.cudagraph = cudagraph - - compilation_counter.num_cudagraph_captured += 1 - - # important: we need to return the output, rather than - # the weak ref of the output, so that pytorch can correctly - # manage the memory during cuda graph capture - return output - - if self.is_debugging_mode: - # check if the input addresses are the same - new_input_addresses = [ - x.data_ptr() for x in args if isinstance(x, torch.Tensor) - ] - assert new_input_addresses == entry.input_addresses, ( - "Input addresses for cudagraphs are different during " - f"replay. Expected {entry.input_addresses}, got " - f"{new_input_addresses}") - - entry.cudagraph.replay() - return entry.output - - return cudagraph_runnable - - -class CUDAPiecewiseBackend: - - def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig, - graph_pool: Any, piecewise_compile_index: int, - total_piecewise_compiles: int, sym_shape_indices: list[int], - compiled_graph_for_general_shape: Callable, - vllm_backend: VllmBackend): - """ - The backend for piecewise compilation. - It mainly handles the compilation and cudagraph capturing. - - We will compile `self.graph` once for the general shape, - and then compile for different shapes specified in - `compilation_config.compile_sizes`. - - Independently, we will capture cudagraph for different shapes. - - If a shape needs both compilation and cudagraph, we will - compile it first, and then capture cudagraph. 
- """ - self.graph = graph - self.vllm_config = vllm_config - self.compilation_config = vllm_config.compilation_config - self.graph_pool = graph_pool - self.piecewise_compile_index = piecewise_compile_index - self.total_piecewise_compiles = total_piecewise_compiles - self.vllm_backend = vllm_backend - - self.is_first_graph = piecewise_compile_index == 0 - self.is_last_graph = ( - piecewise_compile_index == total_piecewise_compiles - 1) - - self.is_full_graph = total_piecewise_compiles == 1 - - self.compile_sizes: set[int] = set( - self.compilation_config.compile_sizes) - self.cudagraph_capture_sizes: set[int] = set( - self.compilation_config.cudagraph_capture_sizes - ) if self.compilation_config.use_cudagraph else set() - - self.first_run_finished = False - - self.compiled_graph_for_general_shape = compiled_graph_for_general_shape # noqa - - self.sym_shape_indices = sym_shape_indices - - self.is_debugging_mode = envs.VLLM_LOGGING_LEVEL == "DEBUG" - - # the entries for different shapes that we need to either - # compile or capture cudagraph - self.concrete_size_entries: dict[int, ConcreteSizeEntry] = {} - - # to_be_compiled_sizes tracks the remaining sizes to compile, - # and updates during the compilation process, so we need to copy it - self.to_be_compiled_sizes: set[int] = self.compile_sizes.copy() - for shape in self.compile_sizes.union(self.cudagraph_capture_sizes): - usage_type = "full/general" if self.is_full_graph else \ - "piecewise/general" - self.concrete_size_entries[shape] = ConcreteSizeEntry( - runtime_shape=shape, - need_to_compile=shape in self.compile_sizes, - use_cudagraph=shape in self.cudagraph_capture_sizes, - usage_type=usage_type, # for debug logging only - ) - - self.cudagraph_wrapper = CUDAGraphWrapper(vllm_config) - - def check_for_ending_compilation(self): - if self.is_last_graph and not self.to_be_compiled_sizes: - # no specific sizes to compile - # save the hash of the inductor graph for the next run - self.vllm_backend.compiler_manager.save_to_file() - end_monitoring_torch_compile(self.vllm_config) - - def __call__(self, *args) -> Any: - if not self.first_run_finished: - self.first_run_finished = True - self.check_for_ending_compilation() - return self.compiled_graph_for_general_shape(*args) - - runtime_shape = args[self.sym_shape_indices[0]] - if runtime_shape not in self.concrete_size_entries: - # we don't need to do anything for this shape - return self.compiled_graph_for_general_shape(*args) - - entry = self.concrete_size_entries[runtime_shape] - - if entry.runnable is None: - entry.runnable = self.compiled_graph_for_general_shape - - if entry.need_to_compile and not entry.compiled: - entry.compiled = True - self.to_be_compiled_sizes.remove(runtime_shape) - # args are real arguments - entry.runnable = self.vllm_backend.compiler_manager.compile( - self.graph, - args, - self.compilation_config.inductor_compile_config, - self.compilation_config, - graph_index=self.piecewise_compile_index, - num_graphs=self.total_piecewise_compiles, - runtime_shape=runtime_shape) - - # finished compilations for all required shapes - if self.is_last_graph and not self.to_be_compiled_sizes: - self.check_for_ending_compilation() - - # Skip CUDA graphs if this entry doesn't use them OR - # if we're supposed to treat the piecewise graphs as a whole, - # In the latter case, forward_context.skip_attention_cuda_graphs - # is False, and we rely on a wrapper class to capture the full - # cudagraph outside the fx graph. 
- skip_attention_cuda_graphs = get_forward_context( - ).skip_attention_cuda_graphs - if not entry.use_cudagraph or not skip_attention_cuda_graphs: - return entry.runnable(*args) - - if entry.cudagraph_runnable is None: - cudagraph_runnable_config = { - "graph_pool": self.graph_pool, - "debug_capturing": self.is_first_graph, - "gc_disable": not self.is_first_graph, - "weak_ref_output": self.is_last_graph, - } - entry.cudagraph_runnable = \ - self.cudagraph_wrapper.create_cudagraph_runnable(entry, - cudagraph_runnable_config) - return entry.cudagraph_runnable(*args) - - -class FullCudagraphWrapper: - - def __init__( - self, - graph: fx.GraphModule, - vllm_config: VllmConfig, - graph_pool: Any, - sym_shape_indices: list[int], - ): - self.graph = graph - self.vllm_config = vllm_config - self.compilation_config = vllm_config.compilation_config - self.graph_pool = graph_pool - self.sym_shape_indices = sym_shape_indices - - self.separate_attention_routine = ( - vllm_config.compilation_config.separate_attention_routine) - - self.is_debugging_mode = envs.VLLM_LOGGING_LEVEL == "DEBUG" - - self.first_run_finished = False - - self.cudagraph_capture_sizes: set[int] = set( - self.compilation_config.cudagraph_capture_sizes - ) if self.compilation_config.use_cudagraph else set() - - self.concrete_size_entries: dict[int, ConcreteSizeEntry] = {} - self.concrete_size_entries_decode: dict[int, ConcreteSizeEntry] = {} - - for shape in self.cudagraph_capture_sizes: - self.concrete_size_entries[shape] = ConcreteSizeEntry( - runtime_shape=shape, - need_to_compile=False, - use_cudagraph=True, - usage_type="general", - ) - if self.separate_attention_routine: - self.concrete_size_entries_decode[shape] = ConcreteSizeEntry( - runtime_shape=shape, - need_to_compile=False, - use_cudagraph=True, - usage_type="decode", - ) - - self.cudagraph_wrapper = CUDAGraphWrapper(vllm_config) - - def __call__(self, *args) -> Any: - if not self.first_run_finished: - self.first_run_finished = True - return self.graph(*args) - list_args = list(args) - runtime_shape = list_args[self.sym_shape_indices[0]].shape[0] - forward_context = get_forward_context() - - if forward_context.skip_attention_cuda_graphs: - # turn back to piecewise cudagraphs backend, which is responsible - # for capturing and running the piecewise cudagraphs. - return self.graph(*args) - - # if not skip, the fx graph and its sub-graphs will only be supposed to - # eagerly run the compiled graphs, which should be cudagraph capturable - # as a whole. - - concrete_size_entries = self.concrete_size_entries - if self.separate_attention_routine and forward_context.is_pure_decode: - concrete_size_entries = self.concrete_size_entries_decode - - if runtime_shape not in concrete_size_entries: - # we don't need to do anything for this shape. 
- return self.graph(*args) - - entry = concrete_size_entries[runtime_shape] - - if entry.runnable is None: - entry.runnable = self.graph - - if not entry.use_cudagraph: - return entry.runnable(*args) - - if entry.cudagraph_runnable is None: - cudagraph_runnable_config = {"graph_pool": self.graph_pool} - entry.cudagraph_runnable = \ - self.cudagraph_wrapper.create_cudagraph_runnable(entry, - cudagraph_runnable_config) - - return entry.cudagraph_runnable(*args) diff --git a/vllm/compilation/piecewise_backend.py b/vllm/compilation/piecewise_backend.py new file mode 100644 index 00000000000..a4afc854306 --- /dev/null +++ b/vllm/compilation/piecewise_backend.py @@ -0,0 +1,179 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import dataclasses +from typing import Any, Callable, Optional + +import torch +import torch.fx as fx + +import vllm.envs as envs +from vllm.compilation.backends import VllmBackend +from vllm.compilation.cuda_graph import CUDAGraphWrapper +from vllm.compilation.monitor import end_monitoring_torch_compile +from vllm.config import VllmConfig, CUDAGraphRuntimeStyle +from vllm.forward_context import get_forward_context +from vllm.logger import init_logger +from vllm.platforms import current_platform +from vllm.utils import resolve_obj_by_qualname +logger = init_logger(__name__) + + +@dataclasses.dataclass +class ConcreteSizeEntry: + runtime_shape: int + need_to_compile: bool # the size is in compile_sizes + use_cudagraph: bool # the size is in cudagraph_capture_sizes + compiled: bool = False + runnable: Callable = None # type: ignore + + usage_type: Optional[str] = None # For debug logging only + + +class PiecewiseBackend: + + def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig, + graph_pool: Any, piecewise_compile_index: int, + total_piecewise_compiles: int, sym_shape_indices: list[int], + compiled_graph_for_general_shape: Callable, + vllm_backend: VllmBackend): + """ + The backend for piecewise compilation. + It mainly handles the compilation and cudagraph capturing. + + We will compile `self.graph` once for the general shape, + and then compile for different shapes specified in + `compilation_config.compile_sizes`. + + Independently, the static graph capturing (e.g. CUDA graph) is handled + by a separate static graph wrapper, which is expected to wrap the + compiled callable of the general shape. 
+ """ + self.graph = graph + self.vllm_config = vllm_config + self.compilation_config = vllm_config.compilation_config + self.graph_pool = graph_pool + self.piecewise_compile_index = piecewise_compile_index + self.total_piecewise_compiles = total_piecewise_compiles + self.vllm_backend = vllm_backend + + self.is_first_graph = piecewise_compile_index == 0 + self.is_last_graph = ( + piecewise_compile_index == total_piecewise_compiles - 1) + + self.is_full_graph = total_piecewise_compiles == 1 + + self.compile_sizes: set[int] = set( + self.compilation_config.compile_sizes) + + + self.first_run_finished = False + + self.compiled_graph_for_general_shape = compiled_graph_for_general_shape # noqa + + self.sym_shape_indices = sym_shape_indices + + self.is_debugging_mode = envs.VLLM_LOGGING_LEVEL == "DEBUG" + + # the entries for different shapes that we need to either + # compile or capture cudagraph + self.concrete_size_entries: dict[int, ConcreteSizeEntry] = {} + + # to_be_compiled_sizes tracks the remaining sizes to compile, + # and updates during the compilation process, so we need to copy it + self.to_be_compiled_sizes: set[int] = self.compile_sizes.copy() + + usage_type = "full/general" if self.is_full_graph else \ + "piecewise/general" + + self.cudagraph_capture_sizes: set[int] = set() + self.cudagraph_runable: Optional[CUDAGraphWrapper] = None + if self.compilation_config.cudagraph_mode > 0: + cudagraph_specific_config = { + "debug_capturing": self.is_first_graph, + "gc_disable": not self.is_first_graph, + "weak_ref_output": self.is_last_graph, + "usage_type" : usage_type } + + # Note: To easier distinguish whether it is under the + # piecewise backend, we always assume CUDAGraphRuntimeStyle.PIECEWISE + # here, no matter it is on a full fx graph or piecewise fx graph. + + static_graph_wrapper_class = resolve_obj_by_qualname( + current_platform.get_static_graph_wrapper_cls()) + self.cudagraph_runable = static_graph_wrapper_class( + self.compiled_graph_for_general_shape, + vllm_config, + self.graph_pool, + runtime_style = CUDAGraphRuntimeStyle.PIECEWISE, + cudagraph_specific_config = cudagraph_specific_config) + + self.cudagraph_capture_sizes = (self.compilation_config.\ + cudagraph_capture_sizes) + + + # We now only keep compilation management inside this class directly. + # The cudagraph logic is delegated to the CUDAGraphWrapper class. 
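The qualified-name indirection above (`current_platform.get_static_graph_wrapper_cls()` resolved via `resolve_obj_by_qualname`) keeps the backend platform-agnostic. The pattern itself is small; below is a standalone sketch using importlib with a stand-in wrapper class. Only `resolve_obj_by_qualname` and the platform hook exist in vLLM; everything else is illustrative:

    # Sketch of resolving a wrapper class from a dotted path at runtime,
    # mirroring how the piecewise backend picks the platform's wrapper.
    import importlib
    from typing import Any, Callable

    def resolve_by_qualname(qualname: str) -> Any:
        # "pkg.module.ClassName" -> the ClassName object from pkg.module
        module_name, _, attr = qualname.rpartition(".")
        return getattr(importlib.import_module(module_name), attr)

    class EagerWrapper:
        """Stand-in wrapper that simply calls the runnable (no capture)."""

        def __init__(self, runnable: Callable, **kwargs: Any):
            self.runnable = runnable

        def __call__(self, *args: Any, **kwargs: Any) -> Any:
            return self.runnable(*args, **kwargs)

    # A CUDA platform would return something like
    # "vllm.compilation.cuda_graph.CUDAGraphWrapper" instead.
    wrapper_cls = resolve_by_qualname(f"{__name__}.EagerWrapper")
    wrapped = wrapper_cls(lambda x: x + 1)
    assert wrapped(41) == 42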
+ for shape in self.compile_sizes.union(self.cudagraph_capture_sizes): + self.concrete_size_entries[shape] = ConcreteSizeEntry( + runtime_shape=shape, + need_to_compile=shape in self.compile_sizes, + use_cudagraph=shape in self.cudagraph_capture_sizes, + runnable=self.compiled_graph_for_general_shape, + usage_type=usage_type, # for debug logging only + ) + + + def check_for_ending_compilation(self): + if self.is_last_graph and not self.to_be_compiled_sizes: + # no specific sizes to compile + # save the hash of the inductor graph for the next run + self.vllm_backend.compiler_manager.save_to_file() + end_monitoring_torch_compile(self.vllm_config) + + def __call__(self, *args) -> Any: + if not self.first_run_finished: + self.first_run_finished = True + self.check_for_ending_compilation() + return self.compiled_graph_for_general_shape(*args) + + runtime_shape = args[self.sym_shape_indices[0]] + if self.is_debugging_mode: + assert runtime_shape==get_forward_context().num_tokens + + if runtime_shape not in self.concrete_size_entries: + # we don't need to do anything for this shape + return self.compiled_graph_for_general_shape(*args) + + entry = self.concrete_size_entries[runtime_shape] + + if entry.need_to_compile and not entry.compiled: + entry.compiled = True + self.to_be_compiled_sizes.remove(runtime_shape) + # args are real arguments + entry.runnable = self.vllm_backend.compiler_manager.compile( + self.graph, + args, + self.compilation_config.inductor_compile_config, + self.compilation_config, + graph_index=self.piecewise_compile_index, + num_graphs=self.total_piecewise_compiles, + runtime_shape=runtime_shape) + + # replace the runnable with the compiled one for + # cudagraph capturing + if self.cudagraph_runable is not None: + self.cudagraph_runable.maybe_replace_runnable(runtime_shape, + entry.runnable) + + # finished compilations for all required shapes + if self.is_last_graph and not self.to_be_compiled_sizes: + self.check_for_ending_compilation() + + if not entry.use_cudagraph: + return entry.runnable(*args) + + # safety check to ensure the cudagraph runnable is not None + assert self.cudagraph_runable is not None + return self.cudagraph_runable(*args) + diff --git a/vllm/config.py b/vllm/config.py index d4e7ddcc678..2388f5afbbe 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -3849,6 +3849,17 @@ class CompilationLevel: DYNAMO_ONCE = 2 PIECEWISE = 3 +class CUDAGraphMode: + # constants for the config of the cudagraph mode + NONE = 0 + PIECEWISE = 1 + FULL = 2 + +class CUDAGraphRuntimeStyle: + # constants same as CUDAGraphMode, but used for runtime dispatching + NONE = 0 + PIECEWISE = 1 + FULL = 2 @config @dataclass @@ -3913,14 +3924,13 @@ class CompilationConfig: - [`custom_ops`][vllm.config.CompilationConfig.custom_ops] - [`splitting_ops`][vllm.config.CompilationConfig.splitting_ops] - CudaGraph capture: - - [`use_cudagraph`][vllm.config.CompilationConfig.use_cudagraph] + - [`cudagraph_mode`][vllm.config.CompilationConfig.cudagraph_mode] - [`cudagraph_capture_sizes`] [vllm.config.CompilationConfig.cudagraph_capture_sizes] - [`cudagraph_num_of_warmups`] [vllm.config.CompilationConfig.cudagraph_num_of_warmups] - [`cudagraph_copy_inputs`] [vllm.config.CompilationConfig.cudagraph_copy_inputs] - - [`full_cuda_graph`][vllm.config.CompilationConfig.full_cuda_graph] - Inductor compilation: - [`use_inductor`][vllm.config.CompilationConfig.use_inductor] - [`compile_sizes`][vllm.config.CompilationConfig.compile_sizes] @@ -3940,7 +3950,7 @@ class CompilationConfig: certain small batchsizes, 
where inductor is good at optimizing. """ # Top-level Compilation control - level: int = 0 + level: int = -1 # -1 for no user-setting, VllmConfig.__post_init__ will handle it # noqa """The level of compilation: - 0: no compilation. @@ -3978,7 +3988,7 @@ class CompilationConfig: By default, all custom ops are enabled when running without Inductor and disabled when running with Inductor: level>=PIECEWISE and use_inductor=True. Inductor generates (fused) Triton kernels for disabled custom ops.""" - splitting_ops: list[str] = field(default_factory=list) + splitting_ops: Optional[list[str]] = None """A list of ops to split the full graph into subgraphs, used in piecewise compilation.""" @@ -4008,6 +4018,33 @@ class CompilationConfig: constructor, e.g. `CompilationConfig(inductor_passes={"a": func})`.""" # CudaGraph compilation + cudagraph_mode: int = field(default_factory= lambda: + 1 if envs.VLLM_USE_V1 else 0) + """ + The mode of the cudagraph. + - 0: NONE, no cudagraph capture. + - 1: PIECEWISE. (v1 default) + - 2: FULL. + For cudagraph_mode > 0, It requires that all input buffers have + fixed addresses and all splitting ops write their outputs to + input buffers. + + PIECEWISE mode build piecewise cudagraph only, keeping the cudagraph + incompatiable ops (i.e. some attention ops) outside the cudagraph + for general flexibility. + + FULL mode instead try to capture full cudagraph for fully compatible + routines (i.e. most attention backend support pure decode batches), + and may fall back to piecewise cudagraph for partially imcompatible + routines. This may provide performance benefits for smaller models. + + Currently, the cudagraph mode is only used for the v1 engine. + Note that the cudagraph logic is generally orthogonal to the + compilation logic. For piecewise cudagraph, the logic is kept + inside the compilation. Meanwhile, the full cudagraph is captured + outside the compilation, and in future it will further supports + cudagraph without compilation. + """ use_cudagraph: bool = field(default_factory=lambda: envs.VLLM_USE_V1) """Whether to use cudagraph inside compilation. - False: cudagraph inside compilation is not used. @@ -4018,8 +4055,9 @@ class CompilationConfig: CompilationLevel.PIECEWISE (aka -O3). Note that this is orthogonal to the cudagraph capture logic outside of compilation. - TODO: move outside cudagraph logic into compilation. - torch.compile will handle cudagraph capture logic in the future.""" + TODO: Now this flag is treated as a placeholder for cudagraph + control inside compilation, will removed it in future. + """ cudagraph_num_of_warmups: int = 0 """Number of warmup runs for cudagraph. It means the first several runs will be treated as warmup runs. @@ -4035,11 +4073,6 @@ class CompilationConfig: are always used, it can set this to False. Otherwise, it should set this to True, and the compiler will copy the input to an internally managed buffer. Default is False.""" - full_cuda_graph: bool = False - """whether to use a full cuda graph for the entire forward pass rather than - splitting certain operations such as attention into subgraphs. Thus this - flag cannot be used together with splitting_ops. This may provide - performance benefits for smaller models.""" separate_attention_routine: bool = False """ Enable a distinct attention calls routine under an attention backend for @@ -4048,15 +4081,6 @@ class CompilationConfig: prefill-decode and pure decode cases. This flag enables us to potentially capture the cudagraph separately for each branch. 
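Taken together, the knobs above are meant to be driven through CompilationConfig. A rough usage sketch, assuming the fields exactly as added in this patch (the engine-level wiring may differ):

    # Sketch: selecting a cudagraph mode via the fields added in this patch.
    from vllm.config import CompilationConfig, CUDAGraphMode

    # Piecewise cudagraphs (v1 default): attention ops stay outside the graphs.
    piecewise_cfg = CompilationConfig(cudagraph_mode=CUDAGraphMode.PIECEWISE)

    # Full cudagraphs where the attention backend allows it, with separate
    # capture routines for mixed and pure-decode batches.
    full_cfg = CompilationConfig(cudagraph_mode=CUDAGraphMode.FULL,
                                 separate_attention_routine=True)

    # No cudagraph capture at all, e.g. for debugging.
    eager_cfg = CompilationConfig(cudagraph_mode=CUDAGraphMode.NONE)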
""" - force_no_split_graph: bool = False - """ - Enforce the fx graph to be a flattened full graph instead of a piecewise fx - graph with submodules (splited by attention ops, as the default behavior). - - Maintaining the full graph may offer benefits in some cases, e.g., enabling - attention-related custom inductor passes, such as attention+quant fusion. - """ - pass_config: PassConfig = field(default_factory=PassConfig) """Custom inductor passes, see PassConfig for more details""" @@ -4255,30 +4279,27 @@ def init_with_cudagraph_sizes(self, def set_splitting_ops_for_v1(self): # NOTE: this function needs to be called if self.separate_attention_routine: - assert self.full_cuda_graph, ( + assert self.cudagraph_mode==CUDAGraphMode.FULL, ( "separate_attention_routine requires " - "full_cuda_graph to be True") - - if self.force_no_split_graph: - assert self.full_cuda_graph, ( - "force_no_split_graph requires full_cuda_graph to be True") - assert not self.splitting_ops, ( - "force_no_split_graph cannot be used together with " - "splitting_ops is not empty. Please set splitting_ops" - "to [] if you want to use force_no_split_graph.") - self.splitting_ops = [] - else: - # NOTE: When full_cuda_graph is True, instead of setting an empty + "cudagraph_mode be CUDAGraphMode.FULL") + + if self.splitting_ops is None: + # NOTE: When using full cudagraph, instead of setting an empty # list and capture the full cudagraph inside the flattened fx # graph, we keep the piecewise fx graph structure but capture the # full cudagraph outside the fx graph. This reduces some cpu # overhead when the runtime batch_size is not cudagraph captured. - # see PR #20059. - if not self.splitting_ops: - self.splitting_ops = [ - "vllm.unified_attention", - "vllm.unified_attention_with_output", - ] + # see https://github.com/vllm-project/vllm/pull/20059 for details. + self.splitting_ops = [ + "vllm.unified_attention", + "vllm.unified_attention_with_output", + ] + elif len(self.splitting_ops)==0: + assert self.cudagraph_mode==CUDAGraphMode.FULL, ( + "Seting splitting_ops as empty list requires " + "cudagraph_mode be CUDAGraphMode.FULL") + + self.splitting_ops = [] @config @@ -4550,7 +4571,8 @@ def __post_init__(self): # By default, V1 uses piecewise CUDA graphs. If full_cuda_graph # is set to True, full CUDA graphs will be used. self.compilation_config.cudagraph_num_of_warmups = 1 - self.compilation_config.level = CompilationLevel.PIECEWISE + if self.compilation_config.level == -1: + self.compilation_config.level = CompilationLevel.PIECEWISE self.compilation_config.set_splitting_ops_for_v1() self._set_cudagraph_sizes() @@ -4571,10 +4593,11 @@ def __post_init__(self): "Disabling `torch.compile`.") self.compilation_config.level = CompilationLevel.NO_COMPILATION - if self.compilation_config.full_cuda_graph and \ + if self.compilation_config.cudagraph_mode == CUDAGraphMode.FULL and \ not self.model_config.disable_cascade_attn: - logger.info("full_cuda_graph is not supported with " - "cascade attention. Disabling cascade attention.") + logger.info("CUDAGraphMode.FULL is not supported with " + "cascade attention currently. 
Disabling cascade" + "attention.") self.model_config.disable_cascade_attn = True disable_chunked_prefill_reasons: list[str] = [] diff --git a/vllm/forward_context.py b/vllm/forward_context.py index 4db26b6d49f..50ebdc27839 100644 --- a/vllm/forward_context.py +++ b/vllm/forward_context.py @@ -11,7 +11,7 @@ import torch.distributed as dist import vllm.envs as envs -from vllm.config import ParallelConfig, VllmConfig +from vllm.config import ParallelConfig, VllmConfig, CUDAGraphRuntimeStyle from vllm.logger import init_logger if TYPE_CHECKING: @@ -92,13 +92,12 @@ class ForwardContext: attn_metadata: Union["AttentionMetadata", dict[str, "AttentionMetadata"]] # TODO: remove after making all virtual_engines share the same kv cache virtual_engine: int # set dynamically for each forward pass + num_tokens: Optional[int] = None # set dynamically for each forward pass dp_metadata: Optional[DPMetadata] = None - # determine whether to use a full cudagraph for attention or piecewise - # cudagraphs that skip the attention part. By default true, we use piecewise - # cudagraphs. - skip_attention_cuda_graphs: bool = True - is_pure_decode: bool = False + # determine the cudagraph style at runtime to be FULL, PIECEWISE, or NONE. + # by default NONE, no cudagraph is used. + cudagraph_runtime_style: int = CUDAGraphRuntimeStyle.NONE _forward_context: Optional[ForwardContext] = None @@ -119,8 +118,7 @@ def set_forward_context( virtual_engine: int = 0, num_tokens: Optional[int] = None, num_tokens_across_dp: Optional[torch.Tensor] = None, - skip_attention_cuda_graphs: bool = True, - is_pure_decode: bool = False, + cudagraph_runtime_style: int = CUDAGraphRuntimeStyle.NONE, ): """A context manager that stores the current forward context, can be attention metadata, etc. @@ -142,11 +140,11 @@ def set_forward_context( _forward_context = ForwardContext( no_compile_layers=vllm_config.compilation_config. static_forward_context, + num_tokens=num_tokens, virtual_engine=virtual_engine, attn_metadata=attn_metadata, dp_metadata=dp_metadata, - skip_attention_cuda_graphs=skip_attention_cuda_graphs, - is_pure_decode=is_pure_decode, + cudagraph_runtime_style=cudagraph_runtime_style, ) try: diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index eb329c342a2..1d4593668a0 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -368,12 +368,8 @@ def use_custom_allreduce(cls) -> bool: return True @classmethod - def get_piecewise_backend_cls(cls) -> str: - return "vllm.compilation.cuda_piecewise_backend.CUDAPiecewiseBackend" # noqa - - @classmethod - def get_fullgraph_wrapper_cls(cls) -> str: - return "vllm.compilation.cuda_piecewise_backend.FullCudagraphWrapper" # noqa + def get_static_graph_wrapper_cls(cls) -> str: + return "vllm.compilation.cuda_graph.CUDAGraphWrapper" @classmethod def stateless_init_device_torch_dist_pg( diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index af3f4a50812..5c624f6d710 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -521,18 +521,11 @@ def get_cu_count(cls, device_id: int = 0) -> int: raise NotImplementedError @classmethod - def get_piecewise_backend_cls(cls) -> str: + def get_static_graph_wrapper_cls(cls) -> str: """ - Get piecewise backend class for piecewise graph. + Get static graph wrapper class for static graph. """ - return "vllm.compilation.base_piecewise_backend.AbstractPiecewiseBackend" # noqa - - @classmethod - def get_fullgraph_wrapper_cls(cls) -> str: - """ - Get fullgraph wrapper class for fullgraph static graph. 
- """ - return "vllm.compilation.base_piecewise_backend.AbstractFullgraphWrapper" # noqa + return "vllm.compilation.base_static_graph.AbstractStaticGraphWrapper" @classmethod def stateless_init_device_torch_dist_pg( diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index ee53a76ceb6..ff63c19c535 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -414,8 +414,8 @@ def is_navi(cls) -> bool: return 'gfx1' in torch.cuda.get_device_properties(0).gcnArchName @classmethod - def get_piecewise_backend_cls(cls) -> str: - return "vllm.compilation.cuda_piecewise_backend.CUDAPiecewiseBackend" # noqa + def get_static_graph_wrapper_cls(cls) -> str: + return "vllm.compilation.cuda_graph.CUDAGraphWrapper" @classmethod def stateless_init_device_torch_dist_pg( diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index 8312961c693..d38fab7fcd2 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -27,7 +27,7 @@ from vllm.utils import cdiv from vllm.v1.attention.backends.utils import ( AttentionMetadataBuilder, CommonAttentionMetadata, get_kv_cache_layout, - make_local_attention_virtual_batches) + make_local_attention_virtual_batches, AttentionCGSupport) from vllm.v1.kv_cache_interface import AttentionSpec from vllm.v1.worker.block_table import BlockTable @@ -141,17 +141,17 @@ def _get_sliding_window_configs( class FlashAttentionMetadataBuilder( AttentionMetadataBuilder[FlashAttentionMetadata]): - full_cudagraph_supported: ClassVar[bool] = True + attn_cudagraph_support: ClassVar[int] = AttentionCGSupport.ALWAYS # FlashAttn support a unified varlen fwd kernel for prefill-decode phase, so # it's ok to either separate attention routine or not for both FA2 or 3. - force_separate_routine: ClassVar[Optional[bool]] = None + # TODO: change the default preference if needed. 
+ prefer_separate_routine: ClassVar[Optional[bool]] = None support_full_cudagraph_only: ClassVar[bool] = True def __init__(self, runner: "GPUModelRunner", kv_cache_spec: AttentionSpec, block_table: BlockTable): model_config = runner.model_config - compilation_config = runner.vllm_config.compilation_config self.runner = runner self.num_heads_q = model_config.get_num_attention_heads( @@ -164,7 +164,7 @@ def __init__(self, runner: "GPUModelRunner", kv_cache_spec: AttentionSpec, self.block_table = block_table self.aot_schedule = (get_flash_attn_version() == 3) - self.use_full_cuda_graph = compilation_config.full_cuda_graph + self.use_full_cuda_graph = self.runner.full_cuda_graph if self.use_full_cuda_graph: # NOTE(lucas): AOT scheduling not supported in full cuda graph mode diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index bffe774b38a..21eb37011c7 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -15,12 +15,13 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionType) from vllm.attention.layer import Attention -from vllm.config import VllmConfig, get_layers_from_vllm_config +from vllm.config import VllmConfig, get_layers_from_vllm_config, CUDAGraphMode from vllm.logger import init_logger from vllm.v1.attention.backends.flash_attn import use_cascade_attention from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, CommonAttentionMetadata, - get_kv_cache_layout) + get_kv_cache_layout, + AttentionCGSupport) from vllm.v1.kv_cache_interface import AttentionSpec from vllm.v1.worker.block_table import BlockTable @@ -218,8 +219,8 @@ def __post_init__(self): class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): - full_cudagraph_supported: ClassVar[bool] = True - force_separate_routine: ClassVar[Optional[bool]] = True + attn_cudagraph_support: ClassVar[int] = AttentionCGSupport.PURE_DECODE_ONLY + prefer_separate_routine: ClassVar[Optional[bool]] = True def __init__(self, runner: GPUModelRunner, kv_cache_spec: AttentionSpec, block_table: BlockTable): @@ -228,8 +229,7 @@ def __init__(self, runner: GPUModelRunner, kv_cache_spec: AttentionSpec, self._workspace_buffer = None self._prefill_wrapper = None # Wrapper for prefill/append self._decode_wrapper = None # Wrapper for decode (general shape) - self.enable_cuda_graph = ( - self.vllm_config.compilation_config.full_cuda_graph) + self.enable_cuda_graph = self.runner.full_cuda_graph if self.enable_cuda_graph: # For full cudagraph capture, one `decode_wrapper` for each batch # size is needed for FlashInfer. 
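The one-wrapper-per-batch-size requirement noted in that comment boils down to caching wrappers keyed by the capture size, roughly as below. The constructor callback is a placeholder, not FlashInfer's real API:

    # Sketch of caching one decode wrapper per cudagraph capture size.
    from typing import Any, Callable

    class DecodeWrapperCache:
        def __init__(self, capture_sizes: list[int],
                     make_decode_wrapper: Callable[[int], Any]):
            self._make = make_decode_wrapper
            # One wrapper per batch size that will be captured in a cudagraph,
            # so each captured graph replays against fixed-size buffers.
            self._wrappers = {size: make_decode_wrapper(size)
                              for size in capture_sizes}

        def get(self, batch_size: int) -> Any:
            # Fall back to a dynamically created wrapper for uncaptured sizes.
            return self._wrappers.get(batch_size) or self._make(batch_size)

    cache = DecodeWrapperCache([1, 2, 4, 8],
                               make_decode_wrapper=lambda n: f"wrapper(bs={n})")
    assert cache.get(4) == "wrapper(bs=4)"
    assert cache.get(3) == "wrapper(bs=3)"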
diff --git a/vllm/v1/attention/backends/mla/flashmla.py b/vllm/v1/attention/backends/mla/flashmla.py index d49a87c8a57..e4f06c47af7 100644 --- a/vllm/v1/attention/backends/mla/flashmla.py +++ b/vllm/v1/attention/backends/mla/flashmla.py @@ -17,6 +17,7 @@ MLACommonImpl, MLACommonMetadata, MLACommonMetadataBuilder) +from vllm.v1.attention.backends.utils import AttentionCGSupport from vllm.v1.kv_cache_interface import AttentionSpec from vllm.v1.worker.block_table import BlockTable @@ -54,8 +55,8 @@ class FlashMLAMetadata(MLACommonMetadata[FlashMLADecodeMetadata]): class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]): - full_cudagraph_supported: ClassVar[bool] = True # Decode-only - force_separate_routine: ClassVar[Optional[bool]] = True + attn_cudagraph_support: ClassVar[int] = AttentionCGSupport.PURE_DECODE_ONLY + prefer_separate_routine: ClassVar[Optional[bool]] = True def __init__(self, runner, kv_cache_spec: AttentionSpec, block_table: BlockTable): diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py index ef856257c38..d37f671dcb4 100644 --- a/vllm/v1/attention/backends/triton_attn.py +++ b/vllm/v1/attention/backends/triton_attn.py @@ -19,7 +19,7 @@ from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata from vllm.v1.attention.backends.utils import ( AttentionMetadataBuilder, CommonAttentionMetadata, - make_local_attention_virtual_batches) + make_local_attention_virtual_batches, AttentionCGSupport) from vllm.v1.kv_cache_interface import AttentionSpec from vllm.v1.worker.block_table import BlockTable @@ -73,9 +73,7 @@ class LocalAttentionMetadata: class TritonAttentionMetadataBuilder( AttentionMetadataBuilder[TritonAttentionMetadata]): - full_cudagraph_supported: ClassVar[bool] = True - - support_full_cudagraph_only: ClassVar[bool] = True + attn_cudagraph_support: ClassVar[int] = AttentionCGSupport.ALWAYS def __init__(self, runner: "GPUModelRunner", kv_cache_spec: AttentionSpec, block_table: BlockTable): diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index 5750070d018..85333312113 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -46,19 +46,33 @@ class CommonAttentionMetadata: M = TypeVar("M") +class AttentionCGSupport: + # constants for the cudagraph support of the attention backend + + ALWAYS = 2 # Cudagraph always supported + # Cudagraph supported for pure decode, need to use piecewise + # if mixed prefill-decode batches + PURE_DECODE_ONLY = 1 + NEVER = 0 # No support + class AttentionMetadataBuilder(abc.ABC, Generic[M]): # Does this backend/builder support CUDA Graphs for attention. - full_cudagraph_supported: ClassVar[bool] = False - # If full cudagraph support, select if this attention backend - # enforce separate rountine to be True, False or None (free). - force_separate_routine: ClassVar[Optional[bool]] = None - - # If the attention backend supports full cudagraph only. Which means, - # it supports full cudagraph for both prefill-decode phase and pure - # decode phase. Used to check if compilation_config.force_no_split_graph - # = True is valid when full cudagraph is enable. - support_full_cudagraph_only: ClassVar[bool] = False + attn_cudagraph_support: ClassVar[int] = AttentionCGSupport.NEVER + # If attn_cudagraph_supported >0, attention backend can set its + # preference of separate rountine to be True, False or None. 
+    # True: expects explicit separate routines for capturing cudagraphs
+    # of pure decode batches and mixed batches. Should be True if
+    # attn_cudagraph_supported is PURE_DECODE_ONLY, and may be faster
+    # to set it to True if attn_cudagraph_supported is ALWAYS.
+    # False: expects to keep a unified kernel routine when
+    # attn_cudagraph_supported is ALWAYS. This is the case if an
+    # attention kernel can dynamically dispatch different optimized
+    # routines inside the kernel, so there is no need to manually
+    # separate them outside the kernel when capturing the cudagraph.
+    # None: indicates no specific preference, and the control is left
+    # to the users.
+    prefer_separate_routine: ClassVar[Optional[bool]] = None
 
     @abstractmethod
     def build(self, common_prefix_len: int,
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 89021a950d9..17717b968dc 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -19,8 +19,11 @@
 from vllm.attention.backends.abstract import AttentionBackend
 from vllm.attention.layer import Attention
 from vllm.compilation.counter import compilation_counter
+from vllm.compilation.cuda_graph import CUDAGraphWrapper
+from vllm.compilation.backends import get_global_graph_pool
 from vllm.config import (CompilationLevel, VllmConfig,
-                         get_layers_from_vllm_config)
+                         get_layers_from_vllm_config,
+                         CUDAGraphMode, CUDAGraphRuntimeStyle)
 from vllm.distributed.eplb.eplb_state import EplbState
 from vllm.distributed.kv_transfer import (get_kv_transfer_group,
                                           has_kv_transfer_group)
@@ -48,7 +51,8 @@
     is_pin_memory_available, round_up)
 from vllm.v1.attention.backends.mamba_attn import Mamba2AttentionBackend
 from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
-                                              CommonAttentionMetadata)
+                                              CommonAttentionMetadata,
+                                              AttentionCGSupport)
 from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
 from vllm.v1.kv_cache_interface import (AttentionSpec, FullAttentionSpec,
                                         KVCacheConfig, KVCacheSpec, MambaSpec,
@@ -86,6 +90,8 @@
 
 logger = init_logger(__name__)
 
+# constant for the pure decode case
+DECODE_BOOLEN = True
 
 class GPUModelRunner(LoRAModelRunnerMixin):
 
@@ -214,11 +220,10 @@ def __init__(
             block_sizes=[self.cache_config.block_size],
         )
 
-        self.use_cuda_graph = (
-            self.vllm_config.compilation_config.level
-            == CompilationLevel.PIECEWISE
-            and self.vllm_config.compilation_config.use_cudagraph
-            and not self.model_config.enforce_eager)
+        self.cudagraph_mode = self.compilation_config.cudagraph_mode
+        self.use_cuda_graph = (self.cudagraph_mode > CUDAGraphMode.NONE
+                               and not self.model_config.enforce_eager)
+
         # TODO(woosuk): Provide an option to tune the max cudagraph batch size.
         # The convention is different.
         # self.cudagraph_batch_sizes sorts in ascending order.
@@ -226,7 +231,7 @@ def __init__(
         self.cudagraph_batch_sizes = list(
             reversed(self.compilation_config.cudagraph_capture_sizes))
 
-        self.full_cuda_graph = self.compilation_config.full_cuda_graph
+        self.full_cuda_graph = self.cudagraph_mode == CUDAGraphMode.FULL
 
         # Cache the device properties.
         self._init_device_properties()
@@ -316,6 +321,14 @@ def __init__(
         # from the KV cache of `shared_kv_cache_layers[layer_name]`.
         self.shared_kv_cache_layers: dict[str, str] = {}
 
+        # Dict to store cudagraph candidates for later runtime dispatching.
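Before the runner state that continues immediately below, it is worth summarizing how the AttentionCGSupport levels and routine preference documented above are meant to combine with the configured cudagraph mode. The helper here is purely illustrative and mirrors the checks this patch performs in initialize_attn_backend and during dispatch:

    # Illustrative mapping from an attention backend's cudagraph support and
    # routine preference to a capture plan; not part of vLLM itself.
    from typing import Optional

    ALWAYS, PURE_DECODE_ONLY, NEVER = 2, 1, 0        # AttentionCGSupport
    MODE_NONE, MODE_PIECEWISE, MODE_FULL = 0, 1, 2   # CUDAGraphMode

    def capture_plan(attn_support: int, cudagraph_mode: int,
                     prefer_separate_routine: Optional[bool]) -> dict:
        if cudagraph_mode == MODE_NONE or attn_support == NEVER:
            return {"full_mixed": False, "full_decode": False,
                    "separate_routine": False}
        full = cudagraph_mode == MODE_FULL
        # Full capture of mixed prefill-decode batches needs ALWAYS support;
        # PURE_DECODE_ONLY still allows full capture of pure decode batches.
        return {
            "full_mixed": full and attn_support == ALWAYS,
            "full_decode": full,
            "separate_routine": full and (attn_support == PURE_DECODE_ONLY
                                          or bool(prefer_separate_routine)),
        }

    # FlashInfer-style backend: full graphs only for pure decode batches.
    assert capture_plan(PURE_DECODE_ONLY, MODE_FULL, True) == {
        "full_mixed": False, "full_decode": True, "separate_routine": True}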
+ self.cudagraph_candidates: dict[tuple, Any] = {} + # if we want to only capture pure decode batches + self.skip_capture_general_batches = False + + self.no_compilation = self.compilation_config.level != \ + CompilationLevel.PIECEWISE or self.model_config.enforce_eager + def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> bool: """ Update the order of requests in the batch based on the attention @@ -1356,14 +1369,11 @@ def execute_model( else: intermediate_tensors = self.sync_and_slice_intermediate_tensors( num_input_tokens, intermediate_tensors, True) - - # Some attention backends only support CUDA Graphs in pure decode. - # If attention doesn't support CUDA Graphs for this batch, we skip them, - # and turn back to the piecewise CUDA graphs. Or if full_cuda_graph is - # False, we always turn to the piecewise CUDA graphs. - skip_attention_cuda_graphs = not self.full_cuda_graph or not attention_cuda_graphs # noqa: E501 - # Note: When skip_attention_cuda_graphs is always False and - # compilition_config.separate_attention_routine is True, as in FA2, + + cudagraph_runtime_style = self._cudagraph_runtime_style( + attention_cuda_graphs) + # Note: When cudagraph_mode is FULL and + # compilation_config.separate_attention_routine is True, as in FA2, # this flag helps to determine the correct routine for the full # cudagraph. is_pure_decode = num_scheduled_tokens == self.input_batch.num_reqs @@ -1375,9 +1385,9 @@ def execute_model( self.vllm_config, num_tokens=num_input_tokens, num_tokens_across_dp=num_tokens_across_dp, - skip_attention_cuda_graphs=skip_attention_cuda_graphs, - is_pure_decode=is_pure_decode, - ): + cudagraph_runtime_style=cudagraph_runtime_style),\ + self.cudagraph_dispatch(cudagraph_runtime_style, + is_pure_decode): self.maybe_setup_kv_connector(scheduler_output) model_output = self.model( @@ -1800,6 +1810,11 @@ def load_model(self) -> None: self.device, self.parallel_config, ) + # Immediately add self.model to cudagraph_candidates + # for profile run. + # Note that self.model always support no cudagraph. + self.cudagraph_candidates.update({ + (CUDAGraphRuntimeStyle.NONE,): self.model}) def save_tensorized_model( self, @@ -1987,7 +2002,9 @@ def _dummy_run( max_query_len = 1 if is_pure_decode else num_tokens attn_metadata: Optional[dict[str, Any]] = None - skip_attention_cuda_graphs = True + cudagraph_runtime_style = CUDAGraphRuntimeStyle.PIECEWISE if \ + not self.no_compilation else CUDAGraphRuntimeStyle.NONE + if capture_attn_cudagraph: # Note: At this step, `capture_attn_cudagraph` should be True or # "auto", but we always treat it as "auto". i.e., always let the @@ -2015,10 +2032,10 @@ def _dummy_run( attention_cuda_graphs = all( b.can_run_in_cudagraph(common_attn_metadata) for b in self.attn_metadata_builders) - skip_attention_cuda_graphs = not attention_cuda_graphs \ - if self.full_cuda_graph else True - - if not skip_attention_cuda_graphs: + cudagraph_runtime_style = self._cudagraph_runtime_style( + attention_cuda_graphs) + + if cudagraph_runtime_style == CUDAGraphRuntimeStyle.FULL: for kv_cache_group_id, kv_cache_group_spec in enumerate( self.kv_cache_config.kv_cache_groups): @@ -2029,10 +2046,14 @@ def _dummy_run( attn_metadata[layer_name] = attn_metadata_i else: attn_metadata = None # reset to None other than empty dict + + if is_profile: + # when profiling, _maybe_initialize_cudagraph() is not called, + # so always run no cudagraph. 
+ cudagraph_runtime_style = CUDAGraphRuntimeStyle.NONE with self.maybe_dummy_run_with_lora(self.lora_config, num_scheduled_tokens): - model = self.model if self.is_multimodal_model: input_ids = None inputs_embeds = self.inputs_embeds[:num_tokens] @@ -2062,9 +2083,10 @@ def _dummy_run( self.vllm_config, num_tokens=num_tokens, num_tokens_across_dp=num_tokens_across_dp, - skip_attention_cuda_graphs=skip_attention_cuda_graphs, - is_pure_decode=is_pure_decode): - outputs = model( + cudagraph_runtime_style=cudagraph_runtime_style), \ + self.cudagraph_dispatch( + cudagraph_runtime_style, is_pure_decode): + outputs = self.model( input_ids=input_ids, positions=positions, intermediate_tensors=intermediate_tensors, @@ -2315,43 +2337,45 @@ def capture_model(self) -> None: with graph_capture(device=self.device): full_cg = self.full_cuda_graph - # If full_cuda_graph is true, automatically determine whether or - # not to capture the attention for the mix prefill-decode (general) - # phase, based on the attention backends. - capture_attn_cudagraph_general = "auto" if full_cg else False - - # Skip capturing batch sizes of 1 in mix prefill-decode if - # separate_attention_routine is on. As bs=1 can treat as a - # pure decode. - start_idx = 0 - if self.vllm_config.compilation_config.separate_attention_routine \ - and len(self.cudagraph_batch_sizes) > 0 \ - and self.cudagraph_batch_sizes[0] == 1: - start_idx = 1 - - # We skip EPLB here since we don't want to record dummy metrics - - # Only rank 0 should print progress bar during capture - compilation_cases = reversed(self.cudagraph_batch_sizes[start_idx:]) - if is_global_first_rank(): - compilation_cases = tqdm(list(compilation_cases), - desc="Capturing CUDA graphs (mix prefill-decode)") - # Capture the mix prefill-decode (general usage) cudagraphs - for num_tokens in compilation_cases: - for _ in range( - self.compilation_config.cudagraph_num_of_warmups): + if not self.skip_capture_general_batches: + # If full_cuda_graph is true, automatically determine whether + # or not to capture the attention for the mix prefill-decode + # phase, based on the attention backends. + capture_attn_cg_general = "auto" if full_cg else False + + # Skip capturing batch sizes of 1 in mix prefill-decode if + # separate_attention_routine is on. As bs=1 can treat as a + # pure decode. 
+ start_idx = 0 + if self.compilation_config.separate_attention_routine \ + and len(self.cudagraph_batch_sizes) > 0 \ + and self.cudagraph_batch_sizes[0] == 1: + start_idx = 1 + + # We skip EPLB here since we don't want to record dummy metrics + + # Only rank 0 should print progress bar during capture + compilation_cases = reversed(self.cudagraph_batch_sizes[ + start_idx:]) + if is_global_first_rank(): + compilation_cases = tqdm(list(compilation_cases), + desc="Capturing CUDA graphs (mix prefill-decode)") + # Capture the mix prefill-decode (general usage) cudagraphs + for num_tokens in compilation_cases: + for _ in range( + self.compilation_config.cudagraph_num_of_warmups): + self._dummy_run( + num_tokens, + capture_attn_cudagraph=capture_attn_cg_general, + is_pure_decode=False, + skip_eplb=True) self._dummy_run( num_tokens, - capture_attn_cudagraph=capture_attn_cudagraph_general, + capture_attn_cudagraph=capture_attn_cg_general, is_pure_decode=False, skip_eplb=True) - self._dummy_run( - num_tokens, - capture_attn_cudagraph=capture_attn_cudagraph_general, - is_pure_decode=False, - skip_eplb=True) - if self.vllm_config.compilation_config.separate_attention_routine: + if self.compilation_config.separate_attention_routine: # Capture the pure decode cudagraphs. Typically a full cudagraph max_num_reqs = self.scheduler_config.max_num_seqs @@ -2361,8 +2385,9 @@ def capture_model(self) -> None: compilation_cases_decode = reversed( decode_cudagraph_batch_sizes) if is_global_first_rank(): - compilation_cases_decode = tqdm(list(compilation_cases_decode), - desc="Capturing CUDA graphs (pure decode)") + compilation_cases_decode = tqdm(list( + compilation_cases_decode), + desc="Capturing CUDA graphs (pure decode)") for num_tokens in tqdm( reversed(decode_cudagraph_batch_sizes), @@ -2386,6 +2411,90 @@ def capture_model(self) -> None: # This usually takes 5~20 seconds. logger.info("Graph capturing finished in %.0f secs, took %.2f GiB", elapsed_time, cuda_graph_size / (1 << 30)) + + def _maybe_initialize_cudagraph(self): + + if self.compilation_config.level == CompilationLevel.PIECEWISE\ + and len(self.compilation_config.splitting_ops)>0: + self.cudagraph_candidates.update({ + (CUDAGraphRuntimeStyle.PIECEWISE,): self.model}) + logger.debug("Piecewise cudagraph initialized") + + if self.full_cuda_graph: + attn_cg = self.attn_metadata_builders[0].attn_cudagraph_support + # create full cudagraph for mix prefill-decode/general batches + if attn_cg == AttentionCGSupport.ALWAYS: + self.cudagraph_candidates.update({ + (CUDAGraphRuntimeStyle.FULL, not DECODE_BOOLEN): + CUDAGraphWrapper( + self.model, self.vllm_config, + get_global_graph_pool(), + runtime_style=CUDAGraphRuntimeStyle.FULL, + cudagraph_specific_config={ + "usage_type": "general" + }) + }) + logger.debug("Full cudagraph for mixed batches initialized") + # create full cudagraph for pure decode batches + if attn_cg == AttentionCGSupport.PURE_DECODE_ONLY or \ + (attn_cg == AttentionCGSupport.ALWAYS and \ + self.compilation_config.separate_attention_routine): + self.cudagraph_candidates.update({ + (CUDAGraphRuntimeStyle.FULL, DECODE_BOOLEN): + CUDAGraphWrapper( + self.model, self.vllm_config, + get_global_graph_pool(), + runtime_style=CUDAGraphRuntimeStyle.FULL, + cudagraph_specific_config={ + "usage_type": "decode" + }) + }) + logger.debug("Full cudagraph for pure decode batches initialized") + + def _cudagraph_runtime_style(self, attn_cuda_graphs): + + # Some attention backends only support CUDA Graphs in pure decode. 
+ # If attention doesn't support CUDA Graphs for this batch, we skip them, + # and turn back to the piecewise CUDA graphs. + cudagraph_runtime_style = CUDAGraphRuntimeStyle.FULL if\ + attn_cuda_graphs else CUDAGraphRuntimeStyle.PIECEWISE + cudagraph_runtime_style = min(self.cudagraph_mode, + cudagraph_runtime_style) + + # PIECEWISE would fall back to NONE if no compilation + if cudagraph_runtime_style == CUDAGraphRuntimeStyle.PIECEWISE and \ + self.no_compilation: + cudagraph_runtime_style = CUDAGraphRuntimeStyle.NONE + + #TODO: can we optimize above logic? + return cudagraph_runtime_style + + + @contextmanager + def cudagraph_dispatch(self, cudagraph_runtime_style: int, + is_pure_decode: bool): + old_model = self.model + assert self.cudagraph_candidates, ("cudagraph_candidates are " + "not initialized.") + # select between no cudagraph and piecewise cudagraph + if cudagraph_runtime_style in [CUDAGraphRuntimeStyle.NONE, + CUDAGraphRuntimeStyle.PIECEWISE]: + self.model = self.cudagraph_candidates.get( + (cudagraph_runtime_style,), None) + else: + # for full cudagraph, select between general batches + # or pure decode batches + decode_case = (DECODE_BOOLEN,) if self.compilation_config.\ + separate_attention_routine and is_pure_decode \ + else (not DECODE_BOOLEN,) + tuple_key = (cudagraph_runtime_style,) + decode_case + self.model = self.cudagraph_candidates.get(tuple_key, None) + assert self.model is not None, ("cudagraph_candidates is not " + "correctly initialized for" + f"({cudagraph_runtime_style}, " + f"{is_pure_decode})") + yield + self.model = old_model def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None: """ @@ -2431,40 +2540,51 @@ def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None: ) if self.full_cuda_graph: - if not attn_metadata_builder_i.full_cudagraph_supported: + attn_cg: int = attn_metadata_builder_i.attn_cudagraph_support + if not attn_cg > 0: raise ValueError( f"Full CUDAGraph not supported for " f"{attn_backend_i.__name__}. Turn off " f"CompilationConfig.full_cuda_graph or use a different" f" attention backend.") - if self.compilation_config.force_no_split_graph: - assert attn_metadata_builder_i.support_full_cudagraph_only, ( # noqa: E501 + if len(self.compilation_config.splitting_ops) == 0: + assert attn_cg == AttentionCGSupport.ALWAYS, ( f"Full CUDAGraph not supported for " f"{attn_backend_i.__name__} with " - f"CompilationConfig.force_no_split_graph=True. " - f"Turn off CompilationConfig.force_no_split_graph" + f"CompilationConfig.splitting_ops = []. " + f"Set it to None (default values) " f"or use a different attention backend.") # check if the attention backends enforce to have separate # routines for mix prefill-decode and pure decode phase - if attn_metadata_builder_i.force_separate_routine is not None \ + if attn_metadata_builder_i.prefer_separate_routine is not None \ and self.compilation_config.separate_attention_routine\ - != attn_metadata_builder_i.force_separate_routine: + != attn_metadata_builder_i.prefer_separate_routine: - expected = attn_metadata_builder_i.force_separate_routine + expected = attn_metadata_builder_i.prefer_separate_routine logger.warning_once( f"Full CUDAGraph for {attn_backend_i.__name__}" - f"enforce CompilationConfig.separate_attention" + f"expect CompilationConfig.separate_attention" f"_rountine as: {expected}. 
Now set it to: " f"{expected}.") - self.compilation_config.separate_attention_rountine = \ + self.compilation_config.separate_attention_routine = \ expected + # for attn_cg is pure decode only, and no compilation, + # we skip capturing mix prefill-decode (general) batches. + if attn_cg == AttentionCGSupport.PURE_DECODE_ONLY and \ + self.no_compilation: + self.skip_capture_general_batches = True self.attn_backends.append(attn_backend_i) self.attn_metadata_builders.append(attn_metadata_builder_i) + # Trigger cudagraph initialization here (after + # initializing attn backends). + # TODO: move this to the better place + self._maybe_initialize_cudagraph() + def may_reinitialize_input_batch(self, kv_cache_config: KVCacheConfig) -> None: """ From 7d4667a7c708214d409ee49e20674bacca13c220 Mon Sep 17 00:00:00 2001 From: fhl2000 <63384265+fhl2000@users.noreply.github.com> Date: Sat, 5 Jul 2025 12:00:29 +0000 Subject: [PATCH 13/33] refactors Signed-off-by: fhl2000 <63384265+fhl2000@users.noreply.github.com> --- vllm/compilation/backends.py | 5 +- vllm/compilation/base_static_graph.py | 4 +- vllm/compilation/cuda_graph.py | 45 ++++++----- vllm/compilation/piecewise_backend.py | 45 +++++------ vllm/config.py | 21 ++--- vllm/forward_context.py | 2 +- vllm/v1/attention/backends/flash_attn.py | 6 +- vllm/v1/attention/backends/flashinfer.py | 8 +- vllm/v1/attention/backends/triton_attn.py | 4 +- vllm/v1/attention/backends/utils.py | 17 ++-- vllm/v1/worker/gpu_model_runner.py | 98 ++++++++++++----------- 11 files changed, 131 insertions(+), 124 deletions(-) diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 8235a7e9d34..1e01e104b16 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -258,12 +258,14 @@ def split_graph(graph: fx.GraphModule, # we share the global graph pool among all the backends global_graph_pool = None + def get_global_graph_pool(): - global global_graph_pool + global global_graph_pool if global_graph_pool is None: global_graph_pool = current_platform.graph_pool_handle() return global_graph_pool + compilation_start_time = 0.0 @@ -396,7 +398,6 @@ def __init__( # them, e.g. backbone (default), eagle_head, etc. self.prefix = prefix or model_tag - global_graph_pool = get_global_graph_pool() # TODO: in the future, if we want to use multiple diff --git a/vllm/compilation/base_static_graph.py b/vllm/compilation/base_static_graph.py index 7c95b6eedf4..77c7cd09e27 100644 --- a/vllm/compilation/base_static_graph.py +++ b/vllm/compilation/base_static_graph.py @@ -13,7 +13,7 @@ class AbstractStaticGraphWrapper(Protocol): """ def __init__(self, runnable: Callable, vllm_config: VllmConfig, - graph_pool: Any, runtime_style: Any, **kwargs): + graph_pool: Any, runtime_style: int, **kwargs): """ Initializes the StaticGraphWrapper class with graph capturing and execution-related configurations. @@ -31,7 +31,7 @@ def __init__(self, runnable: Callable, vllm_config: VllmConfig, configurations. """ raise NotImplementedError - + def maybe_replace_runnable(self, shape: int, runnable: Any): """ Replaces the runnable with a new one for a specific compiled shape. 
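To keep the dispatch design introduced in the gpu_model_runner changes above easy to reason about, here is a minimal, self-contained sketch of a candidate table keyed by (runtime_style,) or (runtime_style, is_pure_decode). TinyDispatcher and its method names are illustrative only; the integer constants mirror the CUDAGraphRuntimeStyle values at this point in the series, and unlike the real runner this sketch falls back to eager execution instead of asserting when a key is missing.

from contextlib import contextmanager

NONE, PIECEWISE, FULL = 0, 1, 2   # mirrors CUDAGraphRuntimeStyle constants
DECODE_BOOLEN = True

class TinyDispatcher:
    """Illustrative stand-in for the runner's cudagraph_candidates table."""

    def __init__(self, eager_fn):
        # Eager execution is always a valid fallback candidate.
        self.candidates = {(NONE,): eager_fn}
        self.current = eager_fn

    def register(self, key, runnable):
        self.candidates[key] = runnable

    @contextmanager
    def dispatch(self, runtime_style, is_pure_decode):
        # NONE/PIECEWISE are keyed by style alone; FULL also keys on whether
        # the batch is pure decode (separate decode-only graphs).
        if runtime_style in (NONE, PIECEWISE):
            key = (runtime_style,)
        else:
            key = (runtime_style, is_pure_decode)
        prev = self.current
        self.current = self.candidates.get(key, self.candidates[(NONE,)])
        try:
            yield self.current
        finally:
            self.current = prev

dispatcher = TinyDispatcher(eager_fn=lambda x: x)
dispatcher.register((FULL, DECODE_BOOLEN), lambda x: x)  # pretend captured graph
with dispatcher.dispatch(FULL, is_pure_decode=True) as runnable:
    runnable(42)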
diff --git a/vllm/compilation/cuda_graph.py b/vllm/compilation/cuda_graph.py index 45d39cc6653..0c5e0b89592 100644 --- a/vllm/compilation/cuda_graph.py +++ b/vllm/compilation/cuda_graph.py @@ -10,7 +10,7 @@ import vllm.envs as envs from vllm.compilation.counter import compilation_counter -from vllm.config import VllmConfig, CUDAGraphRuntimeStyle +from vllm.config import CUDAGraphRuntimeStyle, VllmConfig from vllm.forward_context import get_forward_context from vllm.logger import init_logger from vllm.utils import weak_ref_tensors @@ -39,9 +39,12 @@ class CUDAGraphWrapper: taking responsibility of capturing cudagraph and running the replay. """ - def __init__(self, runnable: Any, vllm_config: VllmConfig, graph_pool: Any, - runtime_style: CUDAGraphRuntimeStyle, - cudagraph_specific_config: dict[str, Any]={}): + def __init__(self, + runnable: Any, + vllm_config: VllmConfig, + graph_pool: Any, + runtime_style: int, + cudagraph_specific_config: Optional[dict[str, Any]] = None): self.runnable = runnable self.vllm_config = vllm_config self.graph_pool = graph_pool @@ -50,30 +53,30 @@ def __init__(self, runnable: Any, vllm_config: VllmConfig, graph_pool: Any, self.first_run_finished = False self.is_debugging_mode = envs.VLLM_LOGGING_LEVEL == "DEBUG" - + assert self.runtime_style >= CUDAGraphRuntimeStyle.PIECEWISE assert graph_pool is not None + if cudagraph_specific_config is None: + cudagraph_specific_config = {} self.debug_capturing = cudagraph_specific_config.get( "debug_capturing", True) - self.gc_disable = cudagraph_specific_config.get( - "gc_disable", False) + self.gc_disable = cudagraph_specific_config.get("gc_disable", False) self.weak_ref_output = cudagraph_specific_config.get( "weak_ref_output", True) - usage_type = cudagraph_specific_config.get("usage_type", None) + usage_type = cudagraph_specific_config.get("usage_type") self.cudagraph_capture_sizes: set[int] = set( - self.compilation_config.cudagraph_capture_sizes - ) + self.compilation_config.cudagraph_capture_sizes) # the entries for different shapes that we need to capture cudagraph self.concrete_cudagraph_entries: dict[int, CUDAGraphEntry] = {} for shape in self.cudagraph_capture_sizes: - + self.concrete_cudagraph_entries[shape] = CUDAGraphEntry( runtime_shape=shape, runnable=self.runnable, usage_type=usage_type, # for debug logging only ) - + def maybe_replace_runnable(self, shape: int, runnable: Callable): # this is a hack to replace a general shape runnable with a compiled # runnable of a specific shape. @@ -82,7 +85,7 @@ def maybe_replace_runnable(self, shape: int, runnable: Callable): entry = self.concrete_cudagraph_entries[shape] assert entry.cudagraph is None, "Cudagraph is already captured" entry.runnable = runnable - + def __call__(self, *args, **kwargs): forward_context = get_forward_context() runtime_shape = forward_context.num_tokens @@ -94,8 +97,8 @@ def __call__(self, *args, **kwargs): return self.runnable(*args, **kwargs) if cudagraph_runtime_style != self.runtime_style: # CUDAGraph runtime style don't match the current - # configuration, so directly call runnable eagerly - # as it's always safe. + # configuration, so directly call runnable eagerly + # as it's always safe. 
return self.runnable(*args, **kwargs) if runtime_shape not in self.concrete_cudagraph_entries: @@ -103,7 +106,6 @@ def __call__(self, *args, **kwargs): return self.runnable(*args, **kwargs) entry = self.concrete_cudagraph_entries[runtime_shape] - if entry.cudagraph is None: if entry.num_finished_warmup < self.compilation_config.cudagraph_num_of_warmups: # noqa @@ -120,9 +122,8 @@ def __call__(self, *args, **kwargs): # Since we capture cudagraph for many different shapes and # capturing is fast, we don't need to log it for every # shape. We only log it in the debug mode. - logger.debug( - "Capturing a cudagraph of %s usage for shape %s", - entry.usage_type, entry.runtime_shape) + logger.debug("Capturing a cudagraph of %s usage for shape %s", + entry.usage_type, entry.runtime_shape) input_addresses = [ x.data_ptr() for x in args if isinstance(x, torch.Tensor) @@ -133,8 +134,8 @@ def __call__(self, *args, **kwargs): with ExitStack() as stack: if self.gc_disable: # during every model forward for piecewise cudagraph - # mode, we will capture many pieces of cudagraphs - # (roughly one per layer). running gc again and again + # mode, we will capture many pieces of cudagraphs + # (roughly one per layer). running gc again and again # across layers will make the cudagraph capture very slow. # therefore, we only run gc for the first graph, # and disable gc for the rest of the graphs. @@ -177,4 +178,4 @@ def __call__(self, *args, **kwargs): f"{new_input_addresses}") entry.cudagraph.replay() - return entry.output \ No newline at end of file + return entry.output diff --git a/vllm/compilation/piecewise_backend.py b/vllm/compilation/piecewise_backend.py index a4afc854306..95592a19892 100644 --- a/vllm/compilation/piecewise_backend.py +++ b/vllm/compilation/piecewise_backend.py @@ -4,18 +4,18 @@ import dataclasses from typing import Any, Callable, Optional -import torch import torch.fx as fx import vllm.envs as envs from vllm.compilation.backends import VllmBackend from vllm.compilation.cuda_graph import CUDAGraphWrapper from vllm.compilation.monitor import end_monitoring_torch_compile -from vllm.config import VllmConfig, CUDAGraphRuntimeStyle +from vllm.config import CUDAGraphRuntimeStyle, VllmConfig from vllm.forward_context import get_forward_context from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.utils import resolve_obj_by_qualname + logger = init_logger(__name__) @@ -65,7 +65,6 @@ def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig, self.compile_sizes: set[int] = set( self.compilation_config.compile_sizes) - self.first_run_finished = False @@ -90,14 +89,15 @@ def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig, self.cudagraph_runable: Optional[CUDAGraphWrapper] = None if self.compilation_config.cudagraph_mode > 0: cudagraph_specific_config = { - "debug_capturing": self.is_first_graph, - "gc_disable": not self.is_first_graph, - "weak_ref_output": self.is_last_graph, - "usage_type" : usage_type } - - # Note: To easier distinguish whether it is under the - # piecewise backend, we always assume CUDAGraphRuntimeStyle.PIECEWISE - # here, no matter it is on a full fx graph or piecewise fx graph. + "debug_capturing": self.is_first_graph, + "gc_disable": not self.is_first_graph, + "weak_ref_output": self.is_last_graph, + "usage_type": usage_type + } + + # Note: To easier distinguish whether it is under the + # piecewise backend, we always assume PIECEWISE here, + # no matter it is on a full fx graph or piecewise fx graph. 
static_graph_wrapper_class = resolve_obj_by_qualname( current_platform.get_static_graph_wrapper_cls()) @@ -105,12 +105,11 @@ def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig, self.compiled_graph_for_general_shape, vllm_config, self.graph_pool, - runtime_style = CUDAGraphRuntimeStyle.PIECEWISE, - cudagraph_specific_config = cudagraph_specific_config) - + runtime_style=CUDAGraphRuntimeStyle.PIECEWISE, + cudagraph_specific_config=cudagraph_specific_config) + self.cudagraph_capture_sizes = (self.compilation_config.\ cudagraph_capture_sizes) - # We now only keep compilation management inside this class directly. # The cudagraph logic is delegated to the CUDAGraphWrapper class. @@ -123,7 +122,6 @@ def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig, usage_type=usage_type, # for debug logging only ) - def check_for_ending_compilation(self): if self.is_last_graph and not self.to_be_compiled_sizes: # no specific sizes to compile @@ -136,10 +134,10 @@ def __call__(self, *args) -> Any: self.first_run_finished = True self.check_for_ending_compilation() return self.compiled_graph_for_general_shape(*args) - + runtime_shape = args[self.sym_shape_indices[0]] if self.is_debugging_mode: - assert runtime_shape==get_forward_context().num_tokens + assert runtime_shape == get_forward_context().num_tokens if runtime_shape not in self.concrete_size_entries: # we don't need to do anything for this shape @@ -159,21 +157,20 @@ def __call__(self, *args) -> Any: graph_index=self.piecewise_compile_index, num_graphs=self.total_piecewise_compiles, runtime_shape=runtime_shape) - + # replace the runnable with the compiled one for # cudagraph capturing if self.cudagraph_runable is not None: - self.cudagraph_runable.maybe_replace_runnable(runtime_shape, - entry.runnable) + self.cudagraph_runable.maybe_replace_runnable( + runtime_shape, entry.runnable) # finished compilations for all required shapes if self.is_last_graph and not self.to_be_compiled_sizes: self.check_for_ending_compilation() - + if not entry.use_cudagraph: return entry.runnable(*args) - + # safety check to ensure the cudagraph runnable is not None assert self.cudagraph_runable is not None return self.cudagraph_runable(*args) - diff --git a/vllm/config.py b/vllm/config.py index e5141d0ca65..f8ccfc55927 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -3890,18 +3890,21 @@ class CompilationLevel: DYNAMO_ONCE = 2 PIECEWISE = 3 + class CUDAGraphMode: # constants for the config of the cudagraph mode NONE = 0 PIECEWISE = 1 FULL = 2 + class CUDAGraphRuntimeStyle: # constants same as CUDAGraphMode, but used for runtime dispatching NONE = 0 PIECEWISE = 1 FULL = 2 + @config @dataclass class PassConfig: @@ -4059,8 +4062,8 @@ class CompilationConfig: constructor, e.g. `CompilationConfig(inductor_passes={"a": func})`.""" # CudaGraph compilation - cudagraph_mode: int = field(default_factory= lambda: - 1 if envs.VLLM_USE_V1 else 0) + cudagraph_mode: int = field( + default_factory=lambda: 1 if envs.VLLM_USE_V1 else 0) """ The mode of the cudagraph. - 0: NONE, no cudagraph capture. 
@@ -4320,7 +4323,7 @@ def init_with_cudagraph_sizes(self, def set_splitting_ops_for_v1(self): # NOTE: this function needs to be called if self.separate_attention_routine: - assert self.cudagraph_mode==CUDAGraphMode.FULL, ( + assert self.cudagraph_mode == CUDAGraphMode.FULL, ( "separate_attention_routine requires " "cudagraph_mode be CUDAGraphMode.FULL") @@ -4335,12 +4338,12 @@ def set_splitting_ops_for_v1(self): "vllm.unified_attention", "vllm.unified_attention_with_output", ] - elif len(self.splitting_ops)==0: - assert self.cudagraph_mode==CUDAGraphMode.FULL, ( - "Seting splitting_ops as empty list requires " - "cudagraph_mode be CUDAGraphMode.FULL") - - self.splitting_ops = [] + elif len(self.splitting_ops) == 0: + assert self.cudagraph_mode == CUDAGraphMode.FULL, ( + "Seting splitting_ops as empty list requires " + "cudagraph_mode be CUDAGraphMode.FULL") + + self.splitting_ops = [] @config diff --git a/vllm/forward_context.py b/vllm/forward_context.py index 50ebdc27839..f09c1faba63 100644 --- a/vllm/forward_context.py +++ b/vllm/forward_context.py @@ -11,7 +11,7 @@ import torch.distributed as dist import vllm.envs as envs -from vllm.config import ParallelConfig, VllmConfig, CUDAGraphRuntimeStyle +from vllm.config import CUDAGraphRuntimeStyle, ParallelConfig, VllmConfig from vllm.logger import init_logger if TYPE_CHECKING: diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index 70c2b5b880e..72605b1c3cc 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -26,8 +26,8 @@ from vllm.logger import init_logger from vllm.utils import cdiv from vllm.v1.attention.backends.utils import ( - AttentionMetadataBuilder, CommonAttentionMetadata, get_kv_cache_layout, - make_local_attention_virtual_batches, AttentionCGSupport) + AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata, + get_kv_cache_layout, make_local_attention_virtual_batches) from vllm.v1.kv_cache_interface import AttentionSpec from vllm.v1.worker.block_table import BlockTable @@ -175,7 +175,7 @@ def __init__(self, runner: "GPUModelRunner", kv_cache_spec: AttentionSpec, if not self.aot_schedule: raise ValueError( "AoT scheduling is required for full cuda graph.") - capture_sizes = compilation_config.cudagraph_capture_sizes + capture_sizes = self.runner.compilation_config.cudagraph_capture_sizes # noqa: E501 if not capture_sizes: raise ValueError( "cudagraph_capture_sizes should not be None when " diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index 21eb37011c7..e5e80f82914 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -15,13 +15,13 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionType) from vllm.attention.layer import Attention -from vllm.config import VllmConfig, get_layers_from_vllm_config, CUDAGraphMode +from vllm.config import VllmConfig, get_layers_from_vllm_config from vllm.logger import init_logger from vllm.v1.attention.backends.flash_attn import use_cascade_attention -from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, +from vllm.v1.attention.backends.utils import (AttentionCGSupport, + AttentionMetadataBuilder, CommonAttentionMetadata, - get_kv_cache_layout, - AttentionCGSupport) + get_kv_cache_layout) from vllm.v1.kv_cache_interface import AttentionSpec from vllm.v1.worker.block_table import BlockTable diff --git 
a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py
index d37f671dcb4..6bea1991703 100644
--- a/vllm/v1/attention/backends/triton_attn.py
+++ b/vllm/v1/attention/backends/triton_attn.py
@@ -18,8 +18,8 @@
 from vllm.platforms import current_platform
 from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
 from vllm.v1.attention.backends.utils import (
-    AttentionMetadataBuilder, CommonAttentionMetadata,
-    make_local_attention_virtual_batches, AttentionCGSupport)
+    AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata,
+    make_local_attention_virtual_batches)
 from vllm.v1.kv_cache_interface import AttentionSpec
 from vllm.v1.worker.block_table import BlockTable

diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py
index 2902dabaf90..71a8397eefb 100644
--- a/vllm/v1/attention/backends/utils.py
+++ b/vllm/v1/attention/backends/utils.py
@@ -46,15 +46,16 @@ class CommonAttentionMetadata:

 M = TypeVar("M")

+
 class AttentionCGSupport:
     # constants for the cudagraph support of the attention backend
     ALWAYS = 2  # Cudagraph always supported
     # Cudagraph supported for pure decode, need to use piecewise
     # if mixed prefill-decode batches
-    PURE_DECODE_ONLY = 1
-    NEVER = 0 # No support
-
+    PURE_DECODE_ONLY = 1
+    NEVER = 0  # No support
+

 class AttentionMetadataBuilder(abc.ABC, Generic[M]):
     # Does this backend/builder support CUDA Graphs for attention.
@@ -62,13 +63,13 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]):
     # If attn_cudagraph_supported >0, attention backend can set its
     # preference of separate routine to be True, False or None.
     # True: expect to have explicit separate routines for capturing cudagraph
-    # of pure decode batches and mixed batches. Should be true if 
+    # of pure decode batches and mixed batches. Should be true if
     # attn_cudagraph_supported is PURE_DECODE_ONLY. And may be faster
     # to set it true if attn_cudagraph_supported is ALWAYS.
-    # False: expect to keep a unified kernel routine when 
-    # attn_cudagraph_supported is ALWAYS. It is the case if an
-    # attention kernel can dynamically dispatch different optimzied
-    # rountines inside a kernel, so no need to manually separate them 
+    # False: expect to keep a unified kernel routine when
+    # attn_cudagraph_supported is ALWAYS. It is the case if an
+    # attention kernel can dynamically dispatch different optimized
+    # routines inside a kernel, so no need to manually separate them
     # outside kernel when capturing cudagraph.
     # None: indicates no specific preference, and the control is left
     # to the users.
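A hypothetical backend using this contract could declare itself as in the sketch below. Only attn_cudagraph_support and prefer_separate_routine are names from this patch; the class and the reconciliation helper are illustrative stand-ins for the check performed in GPUModelRunner.initialize_attn_backend.

from typing import ClassVar, Optional

ALWAYS, PURE_DECODE_ONLY, NEVER = 2, 1, 0   # mirrors AttentionCGSupport

class ToyDecodeOnlyBuilder:
    # This backend can only be captured for pure-decode batches, so it asks
    # for a separate capture routine for decode.
    attn_cudagraph_support: ClassVar[int] = PURE_DECODE_ONLY
    prefer_separate_routine: ClassVar[Optional[bool]] = True

def reconcile_separate_routine(builder, configured: bool) -> bool:
    # Mirrors the runner-side check above: a non-None backend preference wins
    # over the user's CompilationConfig.separate_attention_routine setting.
    if builder.prefer_separate_routine is None:
        return configured
    return builder.prefer_separate_routine

assert reconcile_separate_routine(ToyDecodeOnlyBuilder, configured=False) is True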
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 0906d3ded3d..0fc3882aaec 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -18,12 +18,12 @@ from vllm.attention import AttentionType, get_attn_backend from vllm.attention.backends.abstract import AttentionBackend from vllm.attention.layer import Attention +from vllm.compilation.backends import get_global_graph_pool from vllm.compilation.counter import compilation_counter from vllm.compilation.cuda_graph import CUDAGraphWrapper -from vllm.compilation.backends import get_global_graph_pool -from vllm.config import (CompilationLevel, VllmConfig, - get_layers_from_vllm_config, - CUDAGraphMode, CUDAGraphRuntimeStyle) +from vllm.config import (CompilationLevel, CUDAGraphMode, + CUDAGraphRuntimeStyle, VllmConfig, + get_layers_from_vllm_config) from vllm.distributed.eplb.eplb_state import EplbState from vllm.distributed.kv_transfer import (get_kv_transfer_group, has_kv_transfer_group) @@ -50,9 +50,9 @@ check_use_alibi, get_dtype_size, is_pin_memory_available, round_up) from vllm.v1.attention.backends.mamba_attn import Mamba2AttentionBackend -from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, - CommonAttentionMetadata, - AttentionCGSupport) +from vllm.v1.attention.backends.utils import (AttentionCGSupport, + AttentionMetadataBuilder, + CommonAttentionMetadata) from vllm.v1.core.encoder_cache_manager import compute_encoder_budget from vllm.v1.kv_cache_interface import (AttentionSpec, FullAttentionSpec, KVCacheConfig, KVCacheSpec, MambaSpec, @@ -93,6 +93,7 @@ # constant code pure decode DECODE_BOOLEN = True + class GPUModelRunner(LoRAModelRunnerMixin): def __init__( @@ -330,7 +331,6 @@ def __init__( CompilationLevel.PIECEWISE or self.model_config.enforce_eager def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None: - """ Update the order of requests in the batch based on the attention backend's needs. For example, some attention backends (namely MLA) may @@ -1366,9 +1366,9 @@ def execute_model( else: intermediate_tensors = self.sync_and_slice_intermediate_tensors( num_input_tokens, intermediate_tensors, True) - + cudagraph_runtime_style = self._cudagraph_runtime_style( - attention_cuda_graphs) + attention_cuda_graphs) # Note: When cudagraph_mode is FULL and # compilation_config.separate_attention_routine is True, as in FA2, # this flag helps to determine the correct routine for the full @@ -1832,7 +1832,9 @@ def load_model(self) -> None: # for profile run. # Note that self.model always support no cudagraph. self.cudagraph_candidates.update({ - (CUDAGraphRuntimeStyle.NONE,): self.model}) + (CUDAGraphRuntimeStyle.NONE, ): + self.model + }) def save_tensorized_model( self, @@ -2052,7 +2054,7 @@ def _dummy_run( for b in self.attn_metadata_builders) cudagraph_runtime_style = self._cudagraph_runtime_style( attention_cuda_graphs) - + if cudagraph_runtime_style == CUDAGraphRuntimeStyle.FULL: for kv_cache_group_id, kv_cache_group_spec in enumerate( self.kv_cache_config.kv_cache_groups): @@ -2064,7 +2066,7 @@ def _dummy_run( attn_metadata[layer_name] = attn_metadata_i else: attn_metadata = None # reset to None other than empty dict - + if is_profile: # when profiling, _maybe_initialize_cudagraph() is not called, # so always run no cudagraph. 
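The per-layer fan-out in the _dummy_run hunk above (one metadata object per KV-cache group, shared by every layer in that group when a FULL cudagraph is about to be captured) can be summarized by this sketch. build_for_capture is a hypothetical stand-in for the real builder call, and the group/layer attributes are assumed for illustration.

def build_capture_attn_metadata(kv_cache_groups, builders, num_reqs, num_tokens):
    attn_metadata = {}
    for group_id, group in enumerate(kv_cache_groups):
        # one metadata object per group, reused by every layer in the group
        meta = builders[group_id].build_for_capture(num_reqs, num_tokens)
        for layer_name in group.layer_names:
            attn_metadata[layer_name] = meta
    return attn_metadata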
@@ -2369,12 +2371,13 @@ def capture_model(self) -> None: start_idx = 1 # We skip EPLB here since we don't want to record dummy metrics - + # Only rank 0 should print progress bar during capture - compilation_cases = reversed(self.cudagraph_batch_sizes[ - start_idx:]) + compilation_cases = reversed( + self.cudagraph_batch_sizes[start_idx:]) if is_global_first_rank(): - compilation_cases = tqdm(list(compilation_cases), + compilation_cases = tqdm( + list(compilation_cases), desc="Capturing CUDA graphs (mix prefill-decode)") # Capture the mix prefill-decode (general usage) cudagraphs for num_tokens in compilation_cases: @@ -2399,12 +2402,12 @@ def capture_model(self) -> None: x for x in self.cudagraph_batch_sizes if x <= max_num_reqs ] compilation_cases_decode = reversed( - decode_cudagraph_batch_sizes) + decode_cudagraph_batch_sizes) if is_global_first_rank(): - compilation_cases_decode = tqdm(list( - compilation_cases_decode), - desc="Capturing CUDA graphs (pure decode)") - + compilation_cases_decode = tqdm( + list(compilation_cases_decode), + desc="Capturing CUDA graphs (pure decode)") + for num_tokens in tqdm( reversed(decode_cudagraph_batch_sizes), desc="Capturing CUDA graphs (pure decode)", @@ -2427,28 +2430,29 @@ def capture_model(self) -> None: # This usually takes 5~20 seconds. logger.info("Graph capturing finished in %.0f secs, took %.2f GiB", elapsed_time, cuda_graph_size / (1 << 30)) - + def _maybe_initialize_cudagraph(self): - + if self.compilation_config.level == CompilationLevel.PIECEWISE\ and len(self.compilation_config.splitting_ops)>0: self.cudagraph_candidates.update({ - (CUDAGraphRuntimeStyle.PIECEWISE,): self.model}) + (CUDAGraphRuntimeStyle.PIECEWISE, ): + self.model + }) logger.debug("Piecewise cudagraph initialized") - + if self.full_cuda_graph: attn_cg = self.attn_metadata_builders[0].attn_cudagraph_support # create full cudagraph for mix prefill-decode/general batches if attn_cg == AttentionCGSupport.ALWAYS: self.cudagraph_candidates.update({ (CUDAGraphRuntimeStyle.FULL, not DECODE_BOOLEN): - CUDAGraphWrapper( - self.model, self.vllm_config, + CUDAGraphWrapper( + self.model, + self.vllm_config, get_global_graph_pool(), runtime_style=CUDAGraphRuntimeStyle.FULL, - cudagraph_specific_config={ - "usage_type": "general" - }) + cudagraph_specific_config={"usage_type": "general"}) }) logger.debug("Full cudagraph for mixed batches initialized") # create full cudagraph for pure decode batches @@ -2457,18 +2461,18 @@ def _maybe_initialize_cudagraph(self): self.compilation_config.separate_attention_routine): self.cudagraph_candidates.update({ (CUDAGraphRuntimeStyle.FULL, DECODE_BOOLEN): - CUDAGraphWrapper( - self.model, self.vllm_config, + CUDAGraphWrapper( + self.model, + self.vllm_config, get_global_graph_pool(), runtime_style=CUDAGraphRuntimeStyle.FULL, - cudagraph_specific_config={ - "usage_type": "decode" - }) + cudagraph_specific_config={"usage_type": "decode"}) }) - logger.debug("Full cudagraph for pure decode batches initialized") + logger.debug( + "Full cudagraph for pure decode batches initialized") def _cudagraph_runtime_style(self, attn_cuda_graphs): - + # Some attention backends only support CUDA Graphs in pure decode. # If attention doesn't support CUDA Graphs for this batch, we skip them, # and turn back to the piecewise CUDA graphs. 
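Spelled out with plain integers (NONE=0 < PIECEWISE=1 < FULL=2, matching the runtime-style constants at this point in the series), the selection logic of _cudagraph_runtime_style boils down to the following sketch; the function name and arguments are illustrative.

NONE, PIECEWISE, FULL = 0, 1, 2

def pick_runtime_style(cudagraph_mode: int, attn_supports_full: bool,
                       no_compilation: bool) -> int:
    # Full graphs only if this batch's attention supports them...
    style = FULL if attn_supports_full else PIECEWISE
    # ...and never more than what the user configured.
    style = min(cudagraph_mode, style)
    # Piecewise graphs live inside the compiled fx graph, so without
    # compilation the only remaining option is no cudagraph at all.
    if style == PIECEWISE and no_compilation:
        style = NONE
    return style

assert pick_runtime_style(FULL, attn_supports_full=False, no_compilation=False) == PIECEWISE
assert pick_runtime_style(PIECEWISE, attn_supports_full=True, no_compilation=False) == PIECEWISE
assert pick_runtime_style(FULL, attn_supports_full=True, no_compilation=True) == FULL
assert pick_runtime_style(PIECEWISE, attn_supports_full=False, no_compilation=True) == NONE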
@@ -2476,34 +2480,34 @@ def _cudagraph_runtime_style(self, attn_cuda_graphs): attn_cuda_graphs else CUDAGraphRuntimeStyle.PIECEWISE cudagraph_runtime_style = min(self.cudagraph_mode, cudagraph_runtime_style) - + # PIECEWISE would fall back to NONE if no compilation if cudagraph_runtime_style == CUDAGraphRuntimeStyle.PIECEWISE and \ self.no_compilation: cudagraph_runtime_style = CUDAGraphRuntimeStyle.NONE - + #TODO: can we optimize above logic? return cudagraph_runtime_style - @contextmanager def cudagraph_dispatch(self, cudagraph_runtime_style: int, is_pure_decode: bool): old_model = self.model assert self.cudagraph_candidates, ("cudagraph_candidates are " - "not initialized.") + "not initialized.") # select between no cudagraph and piecewise cudagraph - if cudagraph_runtime_style in [CUDAGraphRuntimeStyle.NONE, - CUDAGraphRuntimeStyle.PIECEWISE]: + if cudagraph_runtime_style in [ + CUDAGraphRuntimeStyle.NONE, CUDAGraphRuntimeStyle.PIECEWISE + ]: self.model = self.cudagraph_candidates.get( - (cudagraph_runtime_style,), None) + (cudagraph_runtime_style, ), None) else: # for full cudagraph, select between general batches # or pure decode batches decode_case = (DECODE_BOOLEN,) if self.compilation_config.\ separate_attention_routine and is_pure_decode \ else (not DECODE_BOOLEN,) - tuple_key = (cudagraph_runtime_style,) + decode_case + tuple_key = (cudagraph_runtime_style, ) + decode_case self.model = self.cudagraph_candidates.get(tuple_key, None) assert self.model is not None, ("cudagraph_candidates is not " "correctly initialized for" @@ -2596,7 +2600,7 @@ def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None: self.attn_backends.append(attn_backend_i) self.attn_metadata_builders.append(attn_metadata_builder_i) - # Trigger cudagraph initialization here (after + # Trigger cudagraph initialization here (after # initializing attn backends). # TODO: move this to the better place self._maybe_initialize_cudagraph() From fedff4746411440abf37e7a1b2e9a9559345c127 Mon Sep 17 00:00:00 2001 From: fhl2000 <63384265+fhl2000@users.noreply.github.com> Date: Sat, 5 Jul 2025 13:46:11 +0000 Subject: [PATCH 14/33] fix errors Signed-off-by: fhl2000 <63384265+fhl2000@users.noreply.github.com> --- vllm/config.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index f8ccfc55927..3d70df54661 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -3994,7 +3994,7 @@ class CompilationConfig: certain small batchsizes, where inductor is good at optimizing. """ # Top-level Compilation control - level: int = -1 # -1 for no user-setting, VllmConfig.__post_init__ will handle it # noqa + level: Optional[int] = None """The level of compilation: - 0: no compilation. @@ -4615,10 +4615,14 @@ def __post_init__(self): # By default, V1 uses piecewise CUDA graphs. If full_cuda_graph # is set to True, full CUDA graphs will be used. 
self.compilation_config.cudagraph_num_of_warmups = 1 - if self.compilation_config.level == -1: + if self.compilation_config.level is None: self.compilation_config.level = CompilationLevel.PIECEWISE self.compilation_config.set_splitting_ops_for_v1() + # For V0 or other cases, default to level 0 with no compilation + if self.compilation_config.level is None: + self.compilation_config.level = CompilationLevel.NO_COMPILATION + self._set_cudagraph_sizes() if self.cache_config.cpu_offload_gb > 0 and \ From 833ac5686896e4b80a830abef39347414ba81a64 Mon Sep 17 00:00:00 2001 From: fhl2000 <63384265+fhl2000@users.noreply.github.com> Date: Sat, 5 Jul 2025 14:53:55 +0000 Subject: [PATCH 15/33] fix small error by lazy import Signed-off-by: fhl2000 <63384265+fhl2000@users.noreply.github.com> --- vllm/compilation/cuda_graph.py | 8 ++++++-- vllm/compilation/piecewise_backend.py | 5 ++--- vllm/v1/worker/gpu_model_runner.py | 3 --- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/vllm/compilation/cuda_graph.py b/vllm/compilation/cuda_graph.py index 0c5e0b89592..49542b79172 100644 --- a/vllm/compilation/cuda_graph.py +++ b/vllm/compilation/cuda_graph.py @@ -42,8 +42,8 @@ class CUDAGraphWrapper: def __init__(self, runnable: Any, vllm_config: VllmConfig, - graph_pool: Any, runtime_style: int, + graph_pool: Any = None, cudagraph_specific_config: Optional[dict[str, Any]] = None): self.runnable = runnable self.vllm_config = vllm_config @@ -55,7 +55,11 @@ def __init__(self, self.is_debugging_mode = envs.VLLM_LOGGING_LEVEL == "DEBUG" assert self.runtime_style >= CUDAGraphRuntimeStyle.PIECEWISE - assert graph_pool is not None + if self.graph_pool is None: + # lazy import to avoid triggering some import issues. + from vllm.compilation.backends import get_global_graph_pool + self.graph_pool = get_global_graph_pool() + if cudagraph_specific_config is None: cudagraph_specific_config = {} self.debug_capturing = cudagraph_specific_config.get( diff --git a/vllm/compilation/piecewise_backend.py b/vllm/compilation/piecewise_backend.py index 95592a19892..7a9a12da13c 100644 --- a/vllm/compilation/piecewise_backend.py +++ b/vllm/compilation/piecewise_backend.py @@ -8,7 +8,6 @@ import vllm.envs as envs from vllm.compilation.backends import VllmBackend -from vllm.compilation.cuda_graph import CUDAGraphWrapper from vllm.compilation.monitor import end_monitoring_torch_compile from vllm.config import CUDAGraphRuntimeStyle, VllmConfig from vllm.forward_context import get_forward_context @@ -86,7 +85,7 @@ def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig, "piecewise/general" self.cudagraph_capture_sizes: set[int] = set() - self.cudagraph_runable: Optional[CUDAGraphWrapper] = None + self.cudagraph_runable: Optional[Any] = None if self.compilation_config.cudagraph_mode > 0: cudagraph_specific_config = { "debug_capturing": self.is_first_graph, @@ -104,8 +103,8 @@ def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig, self.cudagraph_runable = static_graph_wrapper_class( self.compiled_graph_for_general_shape, vllm_config, - self.graph_pool, runtime_style=CUDAGraphRuntimeStyle.PIECEWISE, + graph_pool=self.graph_pool, cudagraph_specific_config=cudagraph_specific_config) self.cudagraph_capture_sizes = (self.compilation_config.\ diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 0fc3882aaec..33ed6ecfc88 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -18,7 +18,6 @@ from vllm.attention import AttentionType, 
get_attn_backend from vllm.attention.backends.abstract import AttentionBackend from vllm.attention.layer import Attention -from vllm.compilation.backends import get_global_graph_pool from vllm.compilation.counter import compilation_counter from vllm.compilation.cuda_graph import CUDAGraphWrapper from vllm.config import (CompilationLevel, CUDAGraphMode, @@ -2450,7 +2449,6 @@ def _maybe_initialize_cudagraph(self): CUDAGraphWrapper( self.model, self.vllm_config, - get_global_graph_pool(), runtime_style=CUDAGraphRuntimeStyle.FULL, cudagraph_specific_config={"usage_type": "general"}) }) @@ -2464,7 +2462,6 @@ def _maybe_initialize_cudagraph(self): CUDAGraphWrapper( self.model, self.vllm_config, - get_global_graph_pool(), runtime_style=CUDAGraphRuntimeStyle.FULL, cudagraph_specific_config={"usage_type": "decode"}) }) From d57257dd5e98658adc3fd6336a4a5abbb322dfb6 Mon Sep 17 00:00:00 2001 From: fhl2000 <63384265+fhl2000@users.noreply.github.com> Date: Sat, 5 Jul 2025 15:36:27 +0000 Subject: [PATCH 16/33] handle lint-and-deploy errors for cpu execution Signed-off-by: fhl2000 <63384265+fhl2000@users.noreply.github.com> --- vllm/v1/worker/gpu_model_runner.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 33ed6ecfc88..bf7ac6d7f67 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2489,9 +2489,15 @@ def _cudagraph_runtime_style(self, attn_cuda_graphs): @contextmanager def cudagraph_dispatch(self, cudagraph_runtime_style: int, is_pure_decode: bool): + # if no cudagraph candidates inside other platforms, + # just skip cudagraph dispatching. + if not self.cudagraph_candidates: + logger.warning_once("cudagraphs are not initialized." 
+ " No cudagraph will be used.") + yield + return + old_model = self.model - assert self.cudagraph_candidates, ("cudagraph_candidates are " - "not initialized.") # select between no cudagraph and piecewise cudagraph if cudagraph_runtime_style in [ CUDAGraphRuntimeStyle.NONE, CUDAGraphRuntimeStyle.PIECEWISE From 8b7ea7aa136da891e8e3fd50a505c515fcce6744 Mon Sep 17 00:00:00 2001 From: fhl2000 <63384265+fhl2000@users.noreply.github.com> Date: Sat, 5 Jul 2025 15:56:18 +0000 Subject: [PATCH 17/33] remove redundents Signed-off-by: fhl2000 <63384265+fhl2000@users.noreply.github.com> --- vllm/compilation/backends.py | 9 --------- vllm/v1/worker/gpu_model_runner.py | 2 +- 2 files changed, 1 insertion(+), 10 deletions(-) diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 1e01e104b16..b352682f187 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -590,15 +590,6 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable: any(is_symbolic(d) for d in x.size()) ] - # if self.compilation_config.full_cuda_graph: - # assert self.compilation_config.use_cudagraph, \ - # "full_cuda_graph mode requires use_cudagraph to be True" - # fullgraph_wrapper = resolve_obj_by_qualname( - # current_platform.get_fullgraph_wrapper_cls()) - # self.split_gm = fullgraph_wrapper(self.split_gm, self.vllm_config, - # self.graph_pool, - # self.sym_tensor_indices) - # compiler managed cudagraph input buffers # we assume the first run with symbolic shapes # has the maximum size among all the tensors diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index bf7ac6d7f67..48aad6037fb 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2605,7 +2605,7 @@ def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None: # Trigger cudagraph initialization here (after # initializing attn backends). - # TODO: move this to the better place + # TODO: move this to better place. self._maybe_initialize_cudagraph() def may_reinitialize_input_batch(self, From 328615d308d716c1fee888ffe30b4286884b0399 Mon Sep 17 00:00:00 2001 From: fhl2000 <63384265+fhl2000@users.noreply.github.com> Date: Sun, 6 Jul 2025 03:19:09 +0000 Subject: [PATCH 18/33] Clear Signed-off-by: fhl2000 <63384265+fhl2000@users.noreply.github.com> --- vllm/config.py | 4 ++-- vllm/v1/attention/backends/flash_attn.py | 2 -- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 3d70df54661..c49b3e90e1d 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -4086,8 +4086,8 @@ class CompilationConfig: Note that the cudagraph logic is generally orthogonal to the compilation logic. For piecewise cudagraph, the logic is kept inside the compilation. Meanwhile, the full cudagraph is captured - outside the compilation, and in future it will further supports - cudagraph without compilation. + outside the compilation, and it further supports cudagraph + without compilation. """ use_cudagraph: bool = field(default_factory=lambda: envs.VLLM_USE_V1) """Whether to use cudagraph inside compilation. diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index 72605b1c3cc..a424ece0c0f 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -151,8 +151,6 @@ class FlashAttentionMetadataBuilder( # TODO: change the default preference if needed. 
prefer_separate_routine: ClassVar[Optional[bool]] = None - support_full_cudagraph_only: ClassVar[bool] = True - def __init__(self, runner: "GPUModelRunner", kv_cache_spec: AttentionSpec, block_table: BlockTable): model_config = runner.model_config From debc682ce4db3c5b9f4d6ea89ff3fa53ae0aa81a Mon Sep 17 00:00:00 2001 From: fhl <2410591650@qq.com> Date: Thu, 10 Jul 2025 01:54:33 +0800 Subject: [PATCH 19/33] Big refactors Signed-off-by: fhl <2410591650@qq.com> --- vllm/compilation/backends.py | 49 ++- vllm/compilation/base_static_graph.py | 13 +- vllm/compilation/cuda_graph.py | 100 +++--- vllm/compilation/piecewise_backend.py | 72 +---- vllm/compilation/wrapper.py | 6 +- vllm/config.py | 49 ++- vllm/forward_context.py | 4 +- vllm/platforms/cuda.py | 9 +- vllm/platforms/interface.py | 13 +- vllm/v1/attention/backends/flash_attn.py | 54 ++-- vllm/v1/attention/backends/flashinfer.py | 13 +- vllm/v1/attention/backends/mla/flashmla.py | 4 +- .../attention/backends/mla/rocm_aiter_mla.py | 4 +- vllm/v1/attention/backends/triton_attn.py | 3 +- vllm/v1/attention/backends/utils.py | 36 +-- vllm/v1/cudagraph_dispatcher.py | 144 +++++++++ vllm/v1/worker/gpu_model_runner.py | 286 +++++++----------- 17 files changed, 437 insertions(+), 422 deletions(-) create mode 100644 vllm/v1/cudagraph_dispatcher.py diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index b352682f187..3d79d6a1cf4 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -15,10 +15,11 @@ from torch._dispatch.python import enable_python_dispatcher import vllm.envs as envs -from vllm.config import CompilationConfig, VllmConfig +from vllm.config import (CompilationConfig, VllmConfig, CUDAGraphMode, + CUDAGraphRuntimeStyle) from vllm.logger import init_logger from vllm.platforms import current_platform -from vllm.utils import is_torch_equal_or_newer +from vllm.utils import is_torch_equal_or_newer, resolve_obj_by_qualname from .compiler_interface import (CompilerInterface, EagerAdaptor, InductorAdaptor, InductorStandaloneAdaptor) @@ -255,16 +256,6 @@ def split_graph(graph: fx.GraphModule, return split_gm, outputs -# we share the global graph pool among all the backends -global_graph_pool = None - - -def get_global_graph_pool(): - global global_graph_pool - if global_graph_pool is None: - global_graph_pool = current_platform.graph_pool_handle() - return global_graph_pool - compilation_start_time = 0.0 @@ -327,10 +318,36 @@ def call_module(self, target: torch.fx.node.Target, runtime_shape=None) # Lazy import here to avoid circular import from .piecewise_backend import PiecewiseBackend - self.module.__dict__[target] = PiecewiseBackend( - submod, self.vllm_config, self.graph_pool, index, + from .cuda_graph import CUDAGraphOptions + + piecewise_backend = PiecewiseBackend( + submod, self.vllm_config, index, len(self.compile_submod_names), sym_shape_indices, compiled_graph_for_general_shape, self.vllm_backend) + + if self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE: + # resolve the static graph wrapper class (e.g. CUDAGraphWrapper + # class) as platform dependent. + static_graph_wrapper_class = resolve_obj_by_qualname( + current_platform.get_static_graph_wrapper_cls()) + + # Always assign PIECEWISE runtime style to the + # CUDAGraphWrapper for piecewise_backend, to distinguish + # it from the FULL cudagraph runtime style, no matter it + # is wrapped on a full or piecewise fx graph. 
+ self.module.__dict__[target] = static_graph_wrapper_class( + piecewise_backend, + self.vllm_config, + CUDAGraphRuntimeStyle.PIECEWISE, + self.graph_pool, + cudagraph_options = CUDAGraphOptions( + debug_log_enable=piecewise_backend.is_first_graph, + gc_disable=not piecewise_backend.is_first_graph, + weak_ref_output=piecewise_backend.is_last_graph, + usage_str="piecewise" + )) + else: + self.module.__dict__[target] = piecewise_backend compilation_counter.num_piecewise_capturable_graphs_seen += 1 @@ -398,7 +415,7 @@ def __init__( # them, e.g. backbone (default), eagle_head, etc. self.prefix = prefix or model_tag - global_graph_pool = get_global_graph_pool() + global_graph_pool = current_platform.get_global_graph_pool() # TODO: in the future, if we want to use multiple # streams, it might not be safe to share a global pool. @@ -568,7 +585,7 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable: self._called = True - if not self.compilation_config.use_cudagraph or \ + if self.compilation_config.cudagraph_mode == CUDAGraphMode.NONE or \ not self.compilation_config.cudagraph_copy_inputs: return self.split_gm diff --git a/vllm/compilation/base_static_graph.py b/vllm/compilation/base_static_graph.py index 77c7cd09e27..9e1e9477051 100644 --- a/vllm/compilation/base_static_graph.py +++ b/vllm/compilation/base_static_graph.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import enum from typing import Any, Callable, Protocol from vllm.config import VllmConfig @@ -13,7 +14,7 @@ class AbstractStaticGraphWrapper(Protocol): """ def __init__(self, runnable: Callable, vllm_config: VllmConfig, - graph_pool: Any, runtime_style: int, **kwargs): + graph_pool: Any, runtime_style: enum.Enum, **kwargs): """ Initializes the StaticGraphWrapper class with graph capturing and execution-related configurations. @@ -24,20 +25,14 @@ def __init__(self, runnable: Callable, vllm_config: VllmConfig, graph_pool (Any): Graph memory pool handle, e.g., `torch.cuda.graph_pool_handle()`. - runtime_style (Any): The style of the static - graph runtime. + runtime_style (enum.Enum): The style of the static + graph runtime. e.g. see CUDAGraphRuntimeStyle in vllm/config.py. Keyword Args: kwargs: Additional keyword arguments for platform-specific configurations. """ raise NotImplementedError - def maybe_replace_runnable(self, shape: int, runnable: Any): - """ - Replaces the runnable with a new one for a specific compiled shape. - """ - raise NotImplementedError - def __call__(self, *args, **kwargs) -> Any: """ Executes the wrapped callable. 
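For reference, the smallest conforming implementation of this protocol is a wrapper that never captures anything and simply runs eagerly. The class below does not exist in the tree; it is only a sketch of what the protocol requires from a platform-specific wrapper.

import enum
from typing import Any, Callable

class EagerStaticGraphWrapper:
    """Trivial AbstractStaticGraphWrapper: no capture, no replay."""

    def __init__(self, runnable: Callable, vllm_config: Any,
                 graph_pool: Any, runtime_style: enum.Enum, **kwargs):
        self.runnable = runnable
        self.runtime_style = runtime_style

    def __call__(self, *args, **kwargs) -> Any:
        # A real wrapper would consult the forward context here and decide
        # whether to capture, replay, or fall through; this one always falls
        # through to the wrapped callable.
        return self.runnable(*args, **kwargs)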
diff --git a/vllm/compilation/cuda_graph.py b/vllm/compilation/cuda_graph.py index 49542b79172..3369eeff3dd 100644 --- a/vllm/compilation/cuda_graph.py +++ b/vllm/compilation/cuda_graph.py @@ -14,6 +14,7 @@ from vllm.forward_context import get_forward_context from vllm.logger import init_logger from vllm.utils import weak_ref_tensors +from vllm.platforms import current_platform logger = init_logger(__name__) @@ -21,16 +22,20 @@ @dataclasses.dataclass class CUDAGraphEntry: runtime_shape: int - num_finished_warmup: int = 0 - runnable: Callable = None # type: ignore cudagraph: Optional[torch.cuda.CUDAGraph] = None output: Optional[Any] = None # for cudagraph debugging, track the input addresses # during capture, and check if they are the same during replay - input_addresses: Optional[list[int]] = None + input_addresses: Optional[list[int]] = None - usage_type: Optional[str] = None # For debug logging only + +@dataclasses.dataclass +class CUDAGraphOptions: + debug_log_enable: bool = True + gc_disable: bool = False + weak_ref_output: bool = True + usage_str: Optional[str] = None # For debug logging only class CUDAGraphWrapper: @@ -40,11 +45,11 @@ class CUDAGraphWrapper: """ def __init__(self, - runnable: Any, + runnable: Callable, vllm_config: VllmConfig, - runtime_style: int, - graph_pool: Any = None, - cudagraph_specific_config: Optional[dict[str, Any]] = None): + runtime_style: CUDAGraphRuntimeStyle, + graph_pool: Any = current_platform.get_global_graph_pool(), + cudagraph_options: Optional[CUDAGraphOptions] = None): self.runnable = runnable self.vllm_config = vllm_config self.graph_pool = graph_pool @@ -54,41 +59,23 @@ def __init__(self, self.first_run_finished = False self.is_debugging_mode = envs.VLLM_LOGGING_LEVEL == "DEBUG" - assert self.runtime_style >= CUDAGraphRuntimeStyle.PIECEWISE - if self.graph_pool is None: - # lazy import to avoid triggering some import issues. - from vllm.compilation.backends import get_global_graph_pool - self.graph_pool = get_global_graph_pool() - - if cudagraph_specific_config is None: - cudagraph_specific_config = {} - self.debug_capturing = cudagraph_specific_config.get( - "debug_capturing", True) - self.gc_disable = cudagraph_specific_config.get("gc_disable", False) - self.weak_ref_output = cudagraph_specific_config.get( - "weak_ref_output", True) - usage_type = cudagraph_specific_config.get("usage_type") + # assert runtime_style is not NONE(no cudagraph), otherwise, we don't + # need to initialize a CUDAGraphWrapper. + assert self.runtime_style != CUDAGraphRuntimeStyle.NONE + assert self.graph_pool is not None + + if cudagraph_options is None: + cudagraph_options = CUDAGraphOptions() + self.cudagraph_options = cudagraph_options + self.cudagraph_capture_sizes: set[int] = set( self.compilation_config.cudagraph_capture_sizes) # the entries for different shapes that we need to capture cudagraph self.concrete_cudagraph_entries: dict[int, CUDAGraphEntry] = {} for shape in self.cudagraph_capture_sizes: - self.concrete_cudagraph_entries[shape] = CUDAGraphEntry( - runtime_shape=shape, - runnable=self.runnable, - usage_type=usage_type, # for debug logging only - ) - - def maybe_replace_runnable(self, shape: int, runnable: Callable): - # this is a hack to replace a general shape runnable with a compiled - # runnable of a specific shape. 
- if shape not in self.concrete_cudagraph_entries: - return - entry = self.concrete_cudagraph_entries[shape] - assert entry.cudagraph is None, "Cudagraph is already captured" - entry.runnable = runnable + runtime_shape=shape) def __call__(self, *args, **kwargs): forward_context = get_forward_context() @@ -97,12 +84,17 @@ def __call__(self, *args, **kwargs): if cudagraph_runtime_style == CUDAGraphRuntimeStyle.NONE or\ runtime_shape is None: - # TODO: make sure here is on profile running or eager running + # make sure it's on profile run, eager run, or warmup stage. return self.runnable(*args, **kwargs) if cudagraph_runtime_style != self.runtime_style: - # CUDAGraph runtime style don't match the current - # configuration, so directly call runnable eagerly - # as it's always safe. + # Only triggers capture/replay if the runtime style matches, + # otherwise, we fallback to the original runnable to handle + # no match case. This is a hack to avoid double capturing + # cudagraph and ensure extra safety in situations where we + # have nested CUDAdGraphWrapper structure, e.g., we have + # piecewise cudagraph for piecewise backend, which may be + # further wrapped to obtain a full cudagraph. See #20059 for + # more details. return self.runnable(*args, **kwargs) if runtime_shape not in self.concrete_cudagraph_entries: @@ -112,22 +104,13 @@ def __call__(self, *args, **kwargs): entry = self.concrete_cudagraph_entries[runtime_shape] if entry.cudagraph is None: - if entry.num_finished_warmup < self.compilation_config.cudagraph_num_of_warmups: # noqa - entry.num_finished_warmup += 1 - if self.debug_capturing: - logger.debug( - "Warming up %s/%s of %s usage for shape %s", - entry.num_finished_warmup, - self.compilation_config.cudagraph_num_of_warmups, - entry.usage_type, entry.runtime_shape) - return entry.runnable(*args, **kwargs) - - if self.debug_capturing: + if self.cudagraph_options.debug_log_enable: # Since we capture cudagraph for many different shapes and # capturing is fast, we don't need to log it for every # shape. We only log it in the debug mode. logger.debug("Capturing a cudagraph of %s usage for shape %s", - entry.usage_type, entry.runtime_shape) + self.cudagraph_options.usage_str, + entry.runtime_shape) input_addresses = [ x.data_ptr() for x in args if isinstance(x, torch.Tensor) @@ -136,7 +119,7 @@ def __call__(self, *args, **kwargs): cudagraph = torch.cuda.CUDAGraph() with ExitStack() as stack: - if self.gc_disable: + if self.cudagraph_options.gc_disable: # during every model forward for piecewise cudagraph # mode, we will capture many pieces of cudagraphs # (roughly one per layer). running gc again and again @@ -150,8 +133,8 @@ def __call__(self, *args, **kwargs): # mind-exploding: carefully manage the reference and memory. with torch.cuda.graph(cudagraph, pool=self.graph_pool): # `output` is managed by pytorch's cudagraph pool - output = entry.runnable(*args, **kwargs) - if self.weak_ref_output: + output = self.runnable(*args, **kwargs) + if self.cudagraph_options.weak_ref_output: # by converting it to weak ref, # the original `output` will immediately be released # to save memory. It is only safe to do this for @@ -177,9 +160,10 @@ def __call__(self, *args, **kwargs): x.data_ptr() for x in args if isinstance(x, torch.Tensor) ] assert new_input_addresses == entry.input_addresses, ( - "Input addresses for cudagraphs are different during " - f"replay. 
Expected {entry.input_addresses}, got " - f"{new_input_addresses}") + f"Input addresses for cudagraphs of " + f"{self.cudagraph_options.usage_str} are different " + f"during replay. Expected {entry.input_addresses}, " + f"got {new_input_addresses}") entry.cudagraph.replay() return entry.output diff --git a/vllm/compilation/piecewise_backend.py b/vllm/compilation/piecewise_backend.py index 7a9a12da13c..982118a114c 100644 --- a/vllm/compilation/piecewise_backend.py +++ b/vllm/compilation/piecewise_backend.py @@ -21,37 +21,28 @@ @dataclasses.dataclass class ConcreteSizeEntry: runtime_shape: int - need_to_compile: bool # the size is in compile_sizes - use_cudagraph: bool # the size is in cudagraph_capture_sizes compiled: bool = False runnable: Callable = None # type: ignore - usage_type: Optional[str] = None # For debug logging only - class PiecewiseBackend: def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig, - graph_pool: Any, piecewise_compile_index: int, - total_piecewise_compiles: int, sym_shape_indices: list[int], + piecewise_compile_index: int, total_piecewise_compiles: int, + sym_shape_indices: list[int], compiled_graph_for_general_shape: Callable, vllm_backend: VllmBackend): """ The backend for piecewise compilation. - It mainly handles the compilation and cudagraph capturing. + It mainly handles the compilation. We will compile `self.graph` once for the general shape, and then compile for different shapes specified in `compilation_config.compile_sizes`. - - Independently, the static graph capturing (e.g. CUDA graph) is handled - by a separate static graph wrapper, which is expected to wrap the - compiled callable of the general shape. """ self.graph = graph self.vllm_config = vllm_config self.compilation_config = vllm_config.compilation_config - self.graph_pool = graph_pool self.piecewise_compile_index = piecewise_compile_index self.total_piecewise_compiles = total_piecewise_compiles self.vllm_backend = vllm_backend @@ -73,52 +64,18 @@ def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig, self.is_debugging_mode = envs.VLLM_LOGGING_LEVEL == "DEBUG" - # the entries for different shapes that we need to either - # compile or capture cudagraph + # the entries for different shapes that we need to compile self.concrete_size_entries: dict[int, ConcreteSizeEntry] = {} # to_be_compiled_sizes tracks the remaining sizes to compile, # and updates during the compilation process, so we need to copy it self.to_be_compiled_sizes: set[int] = self.compile_sizes.copy() - - usage_type = "full/general" if self.is_full_graph else \ - "piecewise/general" - - self.cudagraph_capture_sizes: set[int] = set() - self.cudagraph_runable: Optional[Any] = None - if self.compilation_config.cudagraph_mode > 0: - cudagraph_specific_config = { - "debug_capturing": self.is_first_graph, - "gc_disable": not self.is_first_graph, - "weak_ref_output": self.is_last_graph, - "usage_type": usage_type - } - - # Note: To easier distinguish whether it is under the - # piecewise backend, we always assume PIECEWISE here, - # no matter it is on a full fx graph or piecewise fx graph. 
- - static_graph_wrapper_class = resolve_obj_by_qualname( - current_platform.get_static_graph_wrapper_cls()) - self.cudagraph_runable = static_graph_wrapper_class( - self.compiled_graph_for_general_shape, - vllm_config, - runtime_style=CUDAGraphRuntimeStyle.PIECEWISE, - graph_pool=self.graph_pool, - cudagraph_specific_config=cudagraph_specific_config) - - self.cudagraph_capture_sizes = (self.compilation_config.\ - cudagraph_capture_sizes) - - # We now only keep compilation management inside this class directly. - # The cudagraph logic is delegated to the CUDAGraphWrapper class. - for shape in self.compile_sizes.union(self.cudagraph_capture_sizes): + + # We only keep compilation management inside this class directly. + for shape in self.compile_sizes: self.concrete_size_entries[shape] = ConcreteSizeEntry( runtime_shape=shape, - need_to_compile=shape in self.compile_sizes, - use_cudagraph=shape in self.cudagraph_capture_sizes, runnable=self.compiled_graph_for_general_shape, - usage_type=usage_type, # for debug logging only ) def check_for_ending_compilation(self): @@ -144,7 +101,7 @@ def __call__(self, *args) -> Any: entry = self.concrete_size_entries[runtime_shape] - if entry.need_to_compile and not entry.compiled: + if not entry.compiled: entry.compiled = True self.to_be_compiled_sizes.remove(runtime_shape) # args are real arguments @@ -157,19 +114,8 @@ def __call__(self, *args) -> Any: num_graphs=self.total_piecewise_compiles, runtime_shape=runtime_shape) - # replace the runnable with the compiled one for - # cudagraph capturing - if self.cudagraph_runable is not None: - self.cudagraph_runable.maybe_replace_runnable( - runtime_shape, entry.runnable) - # finished compilations for all required shapes if self.is_last_graph and not self.to_be_compiled_sizes: self.check_for_ending_compilation() - if not entry.use_cudagraph: - return entry.runnable(*args) - - # safety check to ensure the cudagraph runnable is not None - assert self.cudagraph_runable is not None - return self.cudagraph_runable(*args) + return entry.runnable(*args) diff --git a/vllm/compilation/wrapper.py b/vllm/compilation/wrapper.py index 2a261c84c3f..2f2349474ea 100644 --- a/vllm/compilation/wrapper.py +++ b/vllm/compilation/wrapper.py @@ -11,7 +11,7 @@ import torch import vllm.envs as envs -from vllm.config import CompilationLevel, get_current_vllm_config +from vllm.config import CompilationLevel, get_current_vllm_config, CUDAGraphMode from vllm.logger import init_logger logger = init_logger(__name__) @@ -113,8 +113,8 @@ def bytecode_hook(self, old_code: CodeType, new_code: CodeType): except Exception: pass - if self.vllm_config.compilation_config.use_cudagraph and \ - "update" in new_code.co_names: + if self.vllm_config.compilation_config.cudagraph_mode != \ + CUDAGraphMode.NONE and "update" in new_code.co_names: import depyf src = depyf.decompile(new_code) msg = "Assigning / modifying buffers of nn.Module during forward pass is not allowed when using cudagraph inside the compiler because it will cause silent errors. Please use eager mode or fix the code. 
The following code contains clues about which buffer is being modified (please search for the usage of the function `update`):\n" + src # noqa diff --git a/vllm/config.py b/vllm/config.py index c49b3e90e1d..9bf8761838e 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -3891,15 +3891,16 @@ class CompilationLevel: PIECEWISE = 3 -class CUDAGraphMode: - # constants for the config of the cudagraph mode +class CUDAGraphMode(enum.Enum): + # constants for the config of the cudagraph mode. NONE = 0 PIECEWISE = 1 FULL = 2 -class CUDAGraphRuntimeStyle: - # constants same as CUDAGraphMode, but used for runtime dispatching +class CUDAGraphRuntimeStyle(enum.Enum): + # constants for concrete cudagraph runtime style, used for + # runtime dispatching. NONE = 0 PIECEWISE = 1 FULL = 2 @@ -4062,16 +4063,17 @@ class CompilationConfig: constructor, e.g. `CompilationConfig(inductor_passes={"a": func})`.""" # CudaGraph compilation - cudagraph_mode: int = field( - default_factory=lambda: 1 if envs.VLLM_USE_V1 else 0) + cudagraph_mode: CUDAGraphMode = field( + default_factory=lambda: CUDAGraphMode.PIECEWISE if envs.VLLM_USE_V1 + else CUDAGraphMode.NONE) """ The mode of the cudagraph. - - 0: NONE, no cudagraph capture. - - 1: PIECEWISE. (v1 default) - - 2: FULL. - For cudagraph_mode > 0, It requires that all input buffers have - fixed addresses and all splitting ops write their outputs to - input buffers. + - NONE, no cudagraph capture. + - PIECEWISE. (v1 default) + - FULL. + For cudagraph_mode != CUDAGraphMode.NONE, it requires that all input + buffers have fixed addresses and all splitting ops write their outputs + to input buffers. PIECEWISE mode build piecewise cudagraph only, keeping the cudagraph incompatiable ops (i.e. some attention ops) outside the cudagraph @@ -4089,19 +4091,6 @@ class CompilationConfig: outside the compilation, and it further supports cudagraph without compilation. """ - use_cudagraph: bool = field(default_factory=lambda: envs.VLLM_USE_V1) - """Whether to use cudagraph inside compilation. - - False: cudagraph inside compilation is not used. - - True: cudagraph inside compilation is used. It requires - that all input buffers have fixed addresses, and all - splitting ops write their outputs to input buffers. - In the vLLM V1 Engine, this flag only applies for - CompilationLevel.PIECEWISE (aka -O3). - Note that this is orthogonal to the cudagraph capture logic - outside of compilation. - TODO: Now this flag is treated as a placeholder for cudagraph - control inside compilation, will removed it in future. - """ cudagraph_num_of_warmups: int = 0 """Number of warmup runs for cudagraph. It means the first several runs will be treated as warmup runs. 
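A sketch of how the enum-based knob is meant to be configured now that use_cudagraph is gone. The direct enum form is straightforward; whether the string form is accepted depends on going through the pydantic validation path (for example from_cli), which relies on the field validator added in the next hunk.

from vllm.config import CompilationConfig, CUDAGraphMode

# programmatic configuration with the enum directly
cfg = CompilationConfig(cudagraph_mode=CUDAGraphMode.FULL,
                        cudagraph_num_of_warmups=1)

# CLI/JSON-style configuration; the "before" validator upper-cases the string
# and maps it onto the enum (assumed to run when parsing via TypeAdapter)
cfg2 = CompilationConfig.from_cli('{"cudagraph_mode": "full"}')
assert cfg2.cudagraph_mode == CUDAGraphMode.FULL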
@@ -4211,6 +4200,16 @@ def from_cli(cls, cli_value: str) -> "CompilationConfig": """ return TypeAdapter(CompilationConfig).validate_json(cli_value) + @field_validator("cudagraph_mode", mode="before") + @classmethod + def validate_cudagraph_mode_before(cls, value: Any) -> Any: + """ + enable parse the `cudagraph_mode` enum type from string + """ + if isinstance(value, str): + return CUDAGraphMode[value.upper()] + return value + def __post_init__(self) -> None: count_none = self.custom_ops.count("none") count_all = self.custom_ops.count("all") diff --git a/vllm/forward_context.py b/vllm/forward_context.py index f09c1faba63..5850fd2003c 100644 --- a/vllm/forward_context.py +++ b/vllm/forward_context.py @@ -97,7 +97,7 @@ class ForwardContext: dp_metadata: Optional[DPMetadata] = None # determine the cudagraph style at runtime to be FULL, PIECEWISE, or NONE. # by default NONE, no cudagraph is used. - cudagraph_runtime_style: int = CUDAGraphRuntimeStyle.NONE + cudagraph_runtime_style: CUDAGraphRuntimeStyle = CUDAGraphRuntimeStyle.NONE _forward_context: Optional[ForwardContext] = None @@ -118,7 +118,7 @@ def set_forward_context( virtual_engine: int = 0, num_tokens: Optional[int] = None, num_tokens_across_dp: Optional[torch.Tensor] = None, - cudagraph_runtime_style: int = CUDAGraphRuntimeStyle.NONE, + cudagraph_runtime_style: CUDAGraphRuntimeStyle = CUDAGraphRuntimeStyle.NONE, ): """A context manager that stores the current forward context, can be attention metadata, etc. diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index 7ec1bf57244..0bbd6f55d2f 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -164,17 +164,22 @@ def check_and_update_config(cls, vllm_config: "VllmConfig") -> None: cache_config.block_size = 64 logger.info( "Forcing kv cache block size to 64 for FlashMLA backend.") + + # lazy import to avoid circular import + from vllm.config import CUDAGraphMode if (envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput" and parallel_config.data_parallel_size > 1 - and vllm_config.compilation_config.use_cudagraph): + and vllm_config.compilation_config.cudagraph_mode + != CUDAGraphMode.NONE): logger.info( "Data Parallel: Forcing enforce eager to be True since DP " "with DeepEP high-throughput kernels are not CUDA Graph " "compatible. The DeepEP low-latency kernels are CUDA Graph " "compatible. Set the all_to_all backend to deepep_low_latency " "to use those kernels instead.") - vllm_config.compilation_config.use_cudagraph = False + + vllm_config.compilation_config.cudagraph_mode = CUDAGraphMode.NONE vllm_config.model_config.enforce_eager = True # TODO (varun): Turning this ON gives incorrect results for the # Deepseek-V2-lite model. diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index a70af2a768a..463e872178a 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -7,7 +7,7 @@ import sys from datetime import timedelta from platform import uname -from typing import TYPE_CHECKING, NamedTuple, Optional, Union +from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Union import numpy as np import torch @@ -30,6 +30,8 @@ SamplingParams = None FlexibleArgumentParser = None +_global_graph_pool = None + logger = init_logger(__name__) @@ -516,6 +518,15 @@ def __getattr__(self, key: str): logger.warning("Current platform %s does not have '%s'" \ " attribute.", self.device_type, key) return None + + def get_global_graph_pool(self) -> Any: + """ + Return the global graph pool for the this platform. 
+ """ + global _global_graph_pool + if _global_graph_pool is None: + _global_graph_pool = self.graph_pool_handle() + return _global_graph_pool @classmethod def get_cu_count(cls, device_id: int = 0) -> int: diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index a424ece0c0f..58a382cb894 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -22,7 +22,7 @@ get_scheduler_metadata, reshape_and_cache_flash) -from vllm.config import VllmConfig, get_layers_from_vllm_config +from vllm.config import VllmConfig, get_layers_from_vllm_config, CUDAGraphMode from vllm.logger import init_logger from vllm.utils import cdiv from vllm.v1.attention.backends.utils import ( @@ -144,16 +144,18 @@ def _get_sliding_window_configs( class FlashAttentionMetadataBuilder( - AttentionMetadataBuilder[FlashAttentionMetadata]): - attn_cudagraph_support: ClassVar[int] = AttentionCGSupport.ALWAYS - # FlashAttn support a unified varlen fwd kernel for prefill-decode phase, so - # it's ok to either separate attention routine or not for both FA2 or 3. - # TODO: change the default preference if needed. - prefer_separate_routine: ClassVar[Optional[bool]] = None + AttentionMetadataBuilder[FlashAttentionMetadata]): + # FA2 launches separte routines for prefill-decode and pure decode batches, + # while FA3 launches a unified varlen fwd kernel for both prefill-decode + # and pure decode batches. + attn_cudagraph_support: ClassVar[AttentionCGSupport] = \ + AttentionCGSupport.ALWAYS_SEPARATE if get_flash_attn_version() == 2 \ + else AttentionCGSupport.ALWAYS_UNIFIED def __init__(self, runner: "GPUModelRunner", kv_cache_spec: AttentionSpec, block_table: BlockTable): model_config = runner.model_config + compilation_config = runner.vllm_config.compilation_config self.runner = runner self.num_heads_q = model_config.get_num_attention_heads( @@ -167,18 +169,12 @@ def __init__(self, runner: "GPUModelRunner", kv_cache_spec: AttentionSpec, self.max_num_splits = 0 # No upper bound on the number of splits. self.aot_schedule = (get_flash_attn_version() == 3) - self.use_full_cuda_graph = self.runner.full_cuda_graph + + self.use_full_cuda_graph = ( + compilation_config.cudagraph_mode == CUDAGraphMode.FULL) - if self.use_full_cuda_graph: - if not self.aot_schedule: - raise ValueError( - "AoT scheduling is required for full cuda graph.") - capture_sizes = self.runner.compilation_config.cudagraph_capture_sizes # noqa: E501 - if not capture_sizes: - raise ValueError( - "cudagraph_capture_sizes should not be None when " - "full_cuda_graph is True.") - self.max_cudagraph_size = max(capture_sizes) + if self.use_full_cuda_graph and self.aot_schedule: + self.max_cudagraph_size = compilation_config.max_capture_size if self.max_cudagraph_size > 992: # This condition derives from FA3's internal heuristic. # TODO(woosuk): Support larger cudagraph sizes. @@ -330,9 +326,9 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens, seqlens=seq_lens, max_seq_len=max_seq_len, causal=True) - - if self.use_full_cuda_graph: - assert scheduler_metadata is not None + # For FA3 + full cudagraph + max_num_splits = 0 + if self.use_full_cuda_graph and scheduler_metadata is not None: n = scheduler_metadata.shape[0] self.scheduler_metadata[:n] = scheduler_metadata # NOTE(woosuk): We should zero out the rest of the scheduler @@ -341,15 +337,13 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens, # output buffer. 
self.scheduler_metadata[n:] = 0 scheduler_metadata = self.scheduler_metadata[:n] - - max_num_splits = 0 - if (self.use_full_cuda_graph - and num_actual_tokens <= self.max_cudagraph_size): - # NOTE(woosuk): Setting num_splits > 1 may increase the memory - # usage, because the intermediate buffers of size [num_splits, - # num_heads, num_tokens, head_size] are allocated. Therefore, - # we only set num_splits when using cuda graphs. - max_num_splits = self.max_num_splits + + if num_actual_tokens <= self.max_cudagraph_size: + # NOTE(woosuk): Setting num_splits > 1 may increase the memory + # usage, because the intermediate buffers of size [num_splits, + # num_heads, num_tokens, head_size] are allocated. Therefore, + # we only set num_splits when using cuda graphs. + max_num_splits = self.max_num_splits attn_metadata = FlashAttentionMetadata( num_actual_tokens=num_actual_tokens, diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index e5e80f82914..bd9e6e65d5b 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -15,7 +15,7 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionType) from vllm.attention.layer import Attention -from vllm.config import VllmConfig, get_layers_from_vllm_config +from vllm.config import CUDAGraphMode, VllmConfig, get_layers_from_vllm_config from vllm.logger import init_logger from vllm.v1.attention.backends.flash_attn import use_cascade_attention from vllm.v1.attention.backends.utils import (AttentionCGSupport, @@ -219,8 +219,8 @@ def __post_init__(self): class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): - attn_cudagraph_support: ClassVar[int] = AttentionCGSupport.PURE_DECODE_ONLY - prefer_separate_routine: ClassVar[Optional[bool]] = True + attn_cudagraph_support: ClassVar[AttentionCGSupport] = \ + AttentionCGSupport.PURE_DECODE_ONLY def __init__(self, runner: GPUModelRunner, kv_cache_spec: AttentionSpec, block_table: BlockTable): @@ -229,14 +229,17 @@ def __init__(self, runner: GPUModelRunner, kv_cache_spec: AttentionSpec, self._workspace_buffer = None self._prefill_wrapper = None # Wrapper for prefill/append self._decode_wrapper = None # Wrapper for decode (general shape) - self.enable_cuda_graph = self.runner.full_cuda_graph + + compilation_config = self.vllm_config.compilation_config + self.enable_cuda_graph = (compilation_config.cudagraph_mode == + CUDAGraphMode.FULL) if self.enable_cuda_graph: # For full cudagraph capture, one `decode_wrapper` for each batch # size is needed for FlashInfer. 
self._decode_wrappers_cudagraph: dict[ int, BatchDecodeWithPagedKVCacheWrapper] = {} self._decode_cudagraph_max_bs = min( - runner.max_num_reqs, runner.cudagraph_batch_sizes[-1]) + runner.max_num_reqs, compilation_config.max_capture_size) self._cascade_wrapper = None # Wrapper for cascade attention diff --git a/vllm/v1/attention/backends/mla/flashmla.py b/vllm/v1/attention/backends/mla/flashmla.py index e4f06c47af7..78c6907efea 100644 --- a/vllm/v1/attention/backends/mla/flashmla.py +++ b/vllm/v1/attention/backends/mla/flashmla.py @@ -55,8 +55,8 @@ class FlashMLAMetadata(MLACommonMetadata[FlashMLADecodeMetadata]): class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]): - attn_cudagraph_support: ClassVar[int] = AttentionCGSupport.PURE_DECODE_ONLY - prefer_separate_routine: ClassVar[Optional[bool]] = True + attn_cudagraph_support: ClassVar[AttentionCGSupport] = \ + AttentionCGSupport.PURE_DECODE_ONLY def __init__(self, runner, kv_cache_spec: AttentionSpec, block_table: BlockTable): diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py index d5f9dfaea06..e72b02e3ac0 100644 --- a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py +++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py @@ -15,6 +15,7 @@ MLACommonImpl, MLACommonMetadata, MLACommonMetadataBuilder) +from vllm.v1.attention.backends.utils import AttentionCGSupport from vllm.v1.kv_cache_interface import AttentionSpec from vllm.v1.worker.block_table import BlockTable @@ -63,7 +64,8 @@ class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]): class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]): - full_cudagraph_supported: ClassVar[bool] = True # decode only + attn_cudagraph_support: ClassVar[AttentionCGSupport] = \ + AttentionCGSupport.PURE_DECODE_ONLY def __init__(self, runner, kv_cache_spec: AttentionSpec, block_table: BlockTable): diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py index 6bea1991703..258f8a5c417 100644 --- a/vllm/v1/attention/backends/triton_attn.py +++ b/vllm/v1/attention/backends/triton_attn.py @@ -73,7 +73,8 @@ class LocalAttentionMetadata: class TritonAttentionMetadataBuilder( AttentionMetadataBuilder[TritonAttentionMetadata]): - attn_cudagraph_support: ClassVar[int] = AttentionCGSupport.ALWAYS + attn_cudagraph_support: ClassVar[AttentionCGSupport] = \ + AttentionCGSupport.ALWAYS_SEPARATE def __init__(self, runner: "GPUModelRunner", kv_cache_spec: AttentionSpec, block_table: BlockTable): diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index 71a8397eefb..3ef030fed98 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -4,6 +4,7 @@ import functools from abc import abstractmethod from dataclasses import dataclass +import enum from typing import TYPE_CHECKING, ClassVar, Generic, Optional, TypeVar import numpy as np @@ -47,33 +48,26 @@ class CommonAttentionMetadata: M = TypeVar("M") -class AttentionCGSupport: - # constants for the cudagraph support of the attention backend +class AttentionCGSupport(enum.Enum): + # Constants for the cudagraph support of the attention backend + # Here we do not consider the cascade attention, as currently + # it is never cudagraph supported. 
- ALWAYS = 2 # Cudagraph always supported - # Cudagraph supported for pure decode, need to use piecewise - # if mixed prefill-decode batches - PURE_DECODE_ONLY = 1 NEVER = 0 # No support + PURE_DECODE_ONLY = 1 + # Cudagraph supported for pure decode, need to use piecewise + # cudagraph or no cudagraph for mixed prefill-decode batches + ALWAYS_UNIFIED = 2 + # Cudagraph always supported with unified routine + ALWAYS_SEPARATE = 3 + # Cudagraph supported for both mixed prefill-decode + # or pure decode attention routines. class AttentionMetadataBuilder(abc.ABC, Generic[M]): # Does this backend/builder support CUDA Graphs for attention. - attn_cudagraph_support: ClassVar[int] = AttentionCGSupport.NEVER - # If attn_cudagraph_supported >0, attention backend can set its - # preference of separate rountine to be True, False or None. - # True: expect to explicit separate routines for capturing cudagraph - # of pure decode batches and mixed batches. Should be true if - # attn_cudagraph_supported is PURE_DECODE_ONLY. And may be faster - # to set it true if attn_cudagraph_supported is ALWAYS. - # False: expect to keep a unified kernel routine when - # attn_cudagraph_supported is ALWAYS. It is the case if an - # attention kernel can dynamically dispatch different optimzied - # rountines inside a kernel, so no need to manually separate them - # outside kernel when capturing cudagraph. - # None: indicates no specific preference, and the control is left - # to the users. - prefer_separate_routine: ClassVar[Optional[bool]] = None + attn_cudagraph_support: ClassVar[AttentionCGSupport] = \ + AttentionCGSupport.NEVER @abstractmethod def build(self, common_prefix_len: int, diff --git a/vllm/v1/cudagraph_dispatcher.py b/vllm/v1/cudagraph_dispatcher.py new file mode 100644 index 00000000000..d3348dcc549 --- /dev/null +++ b/vllm/v1/cudagraph_dispatcher.py @@ -0,0 +1,144 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import TYPE_CHECKING, Any +from vllm.config import (CUDAGraphRuntimeStyle, VllmConfig, CompilationLevel, + CUDAGraphMode) +from vllm.v1.attention.backends.utils import AttentionCGSupport +from vllm.compilation.cuda_graph import CUDAGraphWrapper, CUDAGraphOptions + +if TYPE_CHECKING: + from vllm.v1.worker.gpu_model_runner import GPUModelRunner + +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +# constant for pure decode +DECODE_BOOLEN = True + + +class CudagraphDispatcher: + """ + Runtime cudagraph dispatcher for gpu model runner. + """ + + def __init__(self, runner: "GPUModelRunner", + vllm_config: VllmConfig): + self.runner = runner + self.vllm_config = vllm_config + self.compilation_config = vllm_config.compilation_config + self.cudagraph_mode = self.compilation_config.cudagraph_mode + self.no_compilation = self.runner.no_compilation + + # Dict to store cudagraph candidates for runtime dispatching. + self.cudagraph_candidates: dict[tuple, Any] = {} + + def after_load_model(self): + # add original model to cudagraph_candidates for profile run. + assert self.runner.model is not None, "model is not loaded" + self.cudagraph_candidates.update({ + (CUDAGraphRuntimeStyle.NONE, ): + self.runner.model + }) + + def maybe_initialize_cudagraph(self): + # This is called only after attention backend is initialized. 
+ + if self.compilation_config.level == CompilationLevel.PIECEWISE\ + and len(self.compilation_config.splitting_ops)>0: + self.cudagraph_candidates.update({ + (CUDAGraphRuntimeStyle.PIECEWISE, ): + self.runner.model + }) + logger.debug("Piecewise cudagraph initialized") + + if self.runner.full_cuda_graph: + attn_cg = self.runner.attn_metadata_builders[0].\ + attn_cudagraph_support + # create full cudagraph for mix prefill-decode/general batches + if attn_cg in [AttentionCGSupport.ALWAYS_UNIFIED, + AttentionCGSupport.ALWAYS_SEPARATE] and \ + self.runner.capture_mixed_batches: + self.cudagraph_candidates.update({ + (CUDAGraphRuntimeStyle.FULL, not DECODE_BOOLEN): + CUDAGraphWrapper( + self.runner.model, + self.vllm_config, + runtime_style=CUDAGraphRuntimeStyle.FULL, + cudagraph_options=CUDAGraphOptions( + usage_str="full/mixed")) + }) + logger.debug("Full cudagraph for mixed batches initialized") + # create full cudagraph for pure decode batches. + if self.compilation_config.separate_attention_routine: + self.cudagraph_candidates.update({ + (CUDAGraphRuntimeStyle.FULL, DECODE_BOOLEN): + CUDAGraphWrapper( + self.runner.model, + self.vllm_config, + runtime_style=CUDAGraphRuntimeStyle.FULL, + cudagraph_options=CUDAGraphOptions( + usage_str="full/pure-decode")) + }) + logger.debug( + "Full cudagraph for pure decode batches initialized") + + def get_cudagraph_runtime_style(self, attn_cuda_graphs: bool) -> CUDAGraphRuntimeStyle: # noqa + + if self.cudagraph_mode == CUDAGraphMode.NONE: + return CUDAGraphRuntimeStyle.NONE + + if self.cudagraph_mode == CUDAGraphMode.PIECEWISE: + # safe to direct return as we have already checked + # CUDAGraphMode.PIECEWISE is compatible only when + # enable vllm compilation. + return CUDAGraphRuntimeStyle.PIECEWISE + + # Otherwise, for modes that enable full cudagraph. + + # Some attention backends only support CUDA Graphs in pure decode. + # If attention doesn't support CUDA Graphs for this batch, we skip them, + # and turn back to the piecewise CUDA graphs. + cudagraph_runtime_style = CUDAGraphRuntimeStyle.FULL if\ + attn_cuda_graphs else CUDAGraphRuntimeStyle.PIECEWISE + + # PIECEWISE would fall back to NONE if no compilation + if cudagraph_runtime_style == CUDAGraphRuntimeStyle.PIECEWISE and \ + self.no_compilation: + cudagraph_runtime_style = CUDAGraphRuntimeStyle.NONE + + #TODO: can we optimize above logic? + return cudagraph_runtime_style + + def dispatch(self, cudagraph_runtime_style: CUDAGraphRuntimeStyle, + is_pure_decode: bool) -> Any: + # if no cudagraph candidates inside other platforms, + # just skip cudagraph dispatching. + if not self.cudagraph_candidates: + logger.warning_once("cudagraphs are not initialized." 
+ " No cudagraph will be used.") + return self.runner.model + + # select between no cudagraph and piecewise cudagraph + if cudagraph_runtime_style in [ + CUDAGraphRuntimeStyle.NONE, CUDAGraphRuntimeStyle.PIECEWISE + ]: + selected_model = self.cudagraph_candidates.get( + (cudagraph_runtime_style, ), None) + assert selected_model is not None, ("cudagraph_candidates is not" + " correctly initialized for key: " + f"({cudagraph_runtime_style}, ).") + else: + # for full cudagraph, select between general batches + # or pure decode batches + decode_case = (DECODE_BOOLEN,) if self.compilation_config.\ + separate_attention_routine and is_pure_decode \ + else (not DECODE_BOOLEN,) + tuple_key = (cudagraph_runtime_style, ) + decode_case + selected_model = self.cudagraph_candidates.get(tuple_key, None) + assert selected_model is not None, ("cudagraph_candidates is not" + " correctly initialized for key: " + f"({cudagraph_runtime_style}, " + f"{is_pure_decode}).") + return selected_model \ No newline at end of file diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 48aad6037fb..62efaf8c450 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -70,6 +70,7 @@ from vllm.v1.worker.block_table import BlockTable from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin +from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher from ..sample.logits_processor import LogitsProcessorManager from .utils import (gather_mm_placeholders, initialize_kv_cache_for_kv_sharing, @@ -89,8 +90,6 @@ logger = init_logger(__name__) -# constant code pure decode -DECODE_BOOLEN = True class GPUModelRunner(LoRAModelRunnerMixin): @@ -222,7 +221,7 @@ def __init__( ) self.cudagraph_mode = self.compilation_config.cudagraph_mode - self.use_cuda_graph = (self.cudagraph_mode > CUDAGraphMode.NONE + self.use_cuda_graph = (self.cudagraph_mode != CUDAGraphMode.NONE and not self.model_config.enforce_eager) # TODO(woosuk): Provide an option to tune the max cudagraph batch size. @@ -322,12 +321,15 @@ def __init__( # from the KV cache of `shared_kv_cache_layers[layer_name]`. self.shared_kv_cache_layers: dict[str, str] = {} - # Dict to store cudagraph candidates for later runtime dispatching. - self.cudagraph_candidates: dict[tuple, Any] = {} - # if we want to only capture pure decode batches - self.skip_capture_general_batches = False + # We may disable capturing cudagraph for mixed batches when + # no support (e.g., no piecewise compilation) or want only capturing + # full cudagraph for pure decode batches. + self.capture_mixed_batches = True self.no_compilation = self.compilation_config.level != \ CompilationLevel.PIECEWISE or self.model_config.enforce_eager + + # Cudagraph dispatcher for runtime cudagraph dispatching. 
+ self.cudagraph_dispatcher = CudagraphDispatcher(self, self.vllm_config) def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None: """ @@ -1366,8 +1368,8 @@ def execute_model( intermediate_tensors = self.sync_and_slice_intermediate_tensors( num_input_tokens, intermediate_tensors, True) - cudagraph_runtime_style = self._cudagraph_runtime_style( - attention_cuda_graphs) + cudagraph_runtime_style = self.cudagraph_dispatcher.\ + get_cudagraph_runtime_style(attention_cuda_graphs) # Note: When cudagraph_mode is FULL and # compilation_config.separate_attention_routine is True, as in FA2, # this flag helps to determine the correct routine for the full @@ -1381,12 +1383,14 @@ def execute_model( self.vllm_config, num_tokens=num_input_tokens, num_tokens_across_dp=num_tokens_across_dp, - cudagraph_runtime_style=cudagraph_runtime_style),\ - self.cudagraph_dispatch(cudagraph_runtime_style, - is_pure_decode): + cudagraph_runtime_style=cudagraph_runtime_style): self.maybe_setup_kv_connector(scheduler_output) - model_output = self.model( + model = self.cudagraph_dispatcher.dispatch( + cudagraph_runtime_style, + is_pure_decode) + + model_output = model( input_ids=input_ids, positions=positions, intermediate_tensors=intermediate_tensors, @@ -1827,13 +1831,8 @@ def load_model(self) -> None: self.device, self.parallel_config, ) - # Immediately add self.model to cudagraph_candidates - # for profile run. - # Note that self.model always support no cudagraph. - self.cudagraph_candidates.update({ - (CUDAGraphRuntimeStyle.NONE, ): - self.model - }) + # immediately initialize the dispatcher for profile run + self.cudagraph_dispatcher.after_load_model() def save_tensorized_model( self, @@ -1992,7 +1991,7 @@ def rand_input_ids() -> torch.Tensor: def _dummy_run( self, num_tokens: int, - capture_attn_cudagraph: Union[bool, Literal["auto"]] = False, + cudagraph_runtime_style: CUDAGraphRuntimeStyle = CUDAGraphRuntimeStyle.NONE, # noqa is_pure_decode: bool = False, skip_eplb: bool = False, is_profile: bool = False, @@ -2016,19 +2015,18 @@ def _dummy_run( num_scheduled_tokens = np.array(num_scheduled_tokens_list, dtype=np.int32) - # This lets FA2 to correctly activate the optimized routine for - # pure decoding, i.e., Flashdecoding + an optimization for GQA/MQA. + # If separate_attention_routine for attention backend is enabled when + # use full cudagraph, we need to manually activate the correct routine + # for mixed prefill-decode batches and pure decode batches separately + # during capturing. + # For example, below code switches to the optimized routine of FA2 + # for pure decoding, i.e., Flashdecode + an optimization for GQA/MQA. max_query_len = 1 if is_pure_decode else num_tokens attn_metadata: Optional[dict[str, Any]] = None - cudagraph_runtime_style = CUDAGraphRuntimeStyle.PIECEWISE if \ - not self.no_compilation else CUDAGraphRuntimeStyle.NONE - - if capture_attn_cudagraph: - # Note: At this step, `capture_attn_cudagraph` should be True or - # "auto", but we always treat it as "auto". i.e., always let the - # attention backends to determine whether to capture the attention - # or not. + + + if cudagraph_runtime_style == CUDAGraphRuntimeStyle.FULL: attn_metadata = {} query_start_loc = self.query_start_loc[:num_reqs + 1] @@ -2046,30 +2044,15 @@ def _dummy_run( num_actual_tokens=num_tokens, max_query_len=max_query_len, ) - # If all attention backends can run in a cudagraph, we use a full - # cudagraph for attention. Otherwise, back to piecewise cudagraphs. 
- attention_cuda_graphs = all( - b.can_run_in_cudagraph(common_attn_metadata) - for b in self.attn_metadata_builders) - cudagraph_runtime_style = self._cudagraph_runtime_style( - attention_cuda_graphs) - - if cudagraph_runtime_style == CUDAGraphRuntimeStyle.FULL: - for kv_cache_group_id, kv_cache_group_spec in enumerate( - self.kv_cache_config.kv_cache_groups): - - attn_metadata_i = self.attn_metadata_builders[ - kv_cache_group_id].build_for_cudagraph_capture( - common_attn_metadata) - for layer_name in kv_cache_group_spec.layer_names: - attn_metadata[layer_name] = attn_metadata_i - else: - attn_metadata = None # reset to None other than empty dict + + for kv_cache_group_id, kv_cache_group_spec in enumerate( + self.kv_cache_config.kv_cache_groups): - if is_profile: - # when profiling, _maybe_initialize_cudagraph() is not called, - # so always run no cudagraph. - cudagraph_runtime_style = CUDAGraphRuntimeStyle.NONE + attn_metadata_i = self.attn_metadata_builders[ + kv_cache_group_id].build_for_cudagraph_capture( + common_attn_metadata) + for layer_name in kv_cache_group_spec.layer_names: + attn_metadata[layer_name] = attn_metadata_i with self.maybe_dummy_run_with_lora(self.lora_config, num_scheduled_tokens): @@ -2102,10 +2085,10 @@ def _dummy_run( self.vllm_config, num_tokens=num_tokens, num_tokens_across_dp=num_tokens_across_dp, - cudagraph_runtime_style=cudagraph_runtime_style), \ - self.cudagraph_dispatch( - cudagraph_runtime_style, is_pure_decode): - outputs = self.model( + cudagraph_runtime_style=cudagraph_runtime_style): + model = self.cudagraph_dispatcher.dispatch( + cudagraph_runtime_style, is_pure_decode) + outputs = model( input_ids=input_ids, positions=positions, intermediate_tensors=intermediate_tensors, @@ -2339,8 +2322,8 @@ def capture_model(self) -> None: if not self.use_cuda_graph: logger.warning( "Skipping CUDA graph capture. To turn on CUDA graph capture, " - "set -O %s and ensure `use_cudagraph` was not manually set to " - "False", CompilationLevel.PIECEWISE) + "ensure `cudagraph_mode` was not manually set to %s", + CUDAGraphMode.NONE) return compilation_counter.num_gpu_runner_capture_triggers += 1 @@ -2352,13 +2335,17 @@ def capture_model(self) -> None: # Capture the large shapes first so that the smaller shapes # can reuse the memory pool allocated for the large shapes. with graph_capture(device=self.device): - full_cg = self.full_cuda_graph - - if not self.skip_capture_general_batches: - # If full_cuda_graph is true, automatically determine whether - # or not to capture the attention for the mix prefill-decode - # phase, based on the attention backends. - capture_attn_cg_general = "auto" if full_cg else False + if self.capture_mixed_batches: + # select between full cudagraph and piecewise cudagraph + # for mixed prefill-decode batches. + attn_cuda_graphs = False if self.cudagraph_mode == \ + CUDAGraphMode.PIECEWISE else ( + self.attn_metadata_builders[0].attn_cudagraph_support in [ + AttentionCGSupport.ALWAYS_UNIFIED, + AttentionCGSupport.ALWAYS_SEPARATE, + ]) + cudagraph_runtime_style = CUDAGraphRuntimeStyle.FULL if \ + attn_cuda_graphs else CUDAGraphRuntimeStyle.PIECEWISE # Skip capturing batch sizes of 1 in mix prefill-decode if # separate_attention_routine is on. 
As bs=1 can treat as a @@ -2378,23 +2365,23 @@ def capture_model(self) -> None: compilation_cases = tqdm( list(compilation_cases), desc="Capturing CUDA graphs (mix prefill-decode)") - # Capture the mix prefill-decode (general usage) cudagraphs for num_tokens in compilation_cases: for _ in range( self.compilation_config.cudagraph_num_of_warmups): + # use CUDAGraphRuntimeStyle.NONE (default) for warmup self._dummy_run( num_tokens, - capture_attn_cudagraph=capture_attn_cg_general, is_pure_decode=False, skip_eplb=True) self._dummy_run( num_tokens, - capture_attn_cudagraph=capture_attn_cg_general, + cudagraph_runtime_style=cudagraph_runtime_style, is_pure_decode=False, skip_eplb=True) if self.compilation_config.separate_attention_routine: - # Capture the pure decode cudagraphs. Typically a full cudagraph + # Capture full cudagraph for pure decode. + cudagraph_runtime_style = CUDAGraphRuntimeStyle.FULL max_num_reqs = self.scheduler_config.max_num_seqs decode_cudagraph_batch_sizes = [ @@ -2407,20 +2394,18 @@ def capture_model(self) -> None: list(compilation_cases_decode), desc="Capturing CUDA graphs (pure decode)") - for num_tokens in tqdm( - reversed(decode_cudagraph_batch_sizes), - desc="Capturing CUDA graphs (pure decode)", - total=len(decode_cudagraph_batch_sizes)): + for num_tokens in compilation_cases_decode: for _ in range( self.compilation_config.cudagraph_num_of_warmups): + # use CUDAGraphRuntimeStyle.NONE (default) for warmup self._dummy_run(num_tokens, - capture_attn_cudagraph=full_cg, is_pure_decode=True, skip_eplb=True) - self._dummy_run(num_tokens, - capture_attn_cudagraph=full_cg, - is_pure_decode=True, - skip_eplb=True) + self._dummy_run( + num_tokens, + cudagraph_runtime_style=cudagraph_runtime_style, + is_pure_decode=True, + skip_eplb=True) end_time = time.perf_counter() end_free_gpu_memory = torch.cuda.mem_get_info()[0] @@ -2430,95 +2415,6 @@ def capture_model(self) -> None: logger.info("Graph capturing finished in %.0f secs, took %.2f GiB", elapsed_time, cuda_graph_size / (1 << 30)) - def _maybe_initialize_cudagraph(self): - - if self.compilation_config.level == CompilationLevel.PIECEWISE\ - and len(self.compilation_config.splitting_ops)>0: - self.cudagraph_candidates.update({ - (CUDAGraphRuntimeStyle.PIECEWISE, ): - self.model - }) - logger.debug("Piecewise cudagraph initialized") - - if self.full_cuda_graph: - attn_cg = self.attn_metadata_builders[0].attn_cudagraph_support - # create full cudagraph for mix prefill-decode/general batches - if attn_cg == AttentionCGSupport.ALWAYS: - self.cudagraph_candidates.update({ - (CUDAGraphRuntimeStyle.FULL, not DECODE_BOOLEN): - CUDAGraphWrapper( - self.model, - self.vllm_config, - runtime_style=CUDAGraphRuntimeStyle.FULL, - cudagraph_specific_config={"usage_type": "general"}) - }) - logger.debug("Full cudagraph for mixed batches initialized") - # create full cudagraph for pure decode batches - if attn_cg == AttentionCGSupport.PURE_DECODE_ONLY or \ - (attn_cg == AttentionCGSupport.ALWAYS and \ - self.compilation_config.separate_attention_routine): - self.cudagraph_candidates.update({ - (CUDAGraphRuntimeStyle.FULL, DECODE_BOOLEN): - CUDAGraphWrapper( - self.model, - self.vllm_config, - runtime_style=CUDAGraphRuntimeStyle.FULL, - cudagraph_specific_config={"usage_type": "decode"}) - }) - logger.debug( - "Full cudagraph for pure decode batches initialized") - - def _cudagraph_runtime_style(self, attn_cuda_graphs): - - # Some attention backends only support CUDA Graphs in pure decode. 
- # If attention doesn't support CUDA Graphs for this batch, we skip them, - # and turn back to the piecewise CUDA graphs. - cudagraph_runtime_style = CUDAGraphRuntimeStyle.FULL if\ - attn_cuda_graphs else CUDAGraphRuntimeStyle.PIECEWISE - cudagraph_runtime_style = min(self.cudagraph_mode, - cudagraph_runtime_style) - - # PIECEWISE would fall back to NONE if no compilation - if cudagraph_runtime_style == CUDAGraphRuntimeStyle.PIECEWISE and \ - self.no_compilation: - cudagraph_runtime_style = CUDAGraphRuntimeStyle.NONE - - #TODO: can we optimize above logic? - return cudagraph_runtime_style - - @contextmanager - def cudagraph_dispatch(self, cudagraph_runtime_style: int, - is_pure_decode: bool): - # if no cudagraph candidates inside other platforms, - # just skip cudagraph dispatching. - if not self.cudagraph_candidates: - logger.warning_once("cudagraphs are not initialized." - " No cudagraph will be used.") - yield - return - - old_model = self.model - # select between no cudagraph and piecewise cudagraph - if cudagraph_runtime_style in [ - CUDAGraphRuntimeStyle.NONE, CUDAGraphRuntimeStyle.PIECEWISE - ]: - self.model = self.cudagraph_candidates.get( - (cudagraph_runtime_style, ), None) - else: - # for full cudagraph, select between general batches - # or pure decode batches - decode_case = (DECODE_BOOLEN,) if self.compilation_config.\ - separate_attention_routine and is_pure_decode \ - else (not DECODE_BOOLEN,) - tuple_key = (cudagraph_runtime_style, ) + decode_case - self.model = self.cudagraph_candidates.get(tuple_key, None) - assert self.model is not None, ("cudagraph_candidates is not " - "correctly initialized for" - f"({cudagraph_runtime_style}, " - f"{is_pure_decode})") - yield - self.model = old_model - def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None: """ Initialize the attention backends and attention metadata builders. @@ -2563,42 +2459,66 @@ def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None: ) if self.full_cuda_graph: - attn_cg: int = attn_metadata_builder_i.attn_cudagraph_support - if not attn_cg > 0: + attn_cg: AttentionCGSupport = \ + attn_metadata_builder_i.attn_cudagraph_support + if attn_cg == AttentionCGSupport.NEVER: raise ValueError( f"Full CUDAGraph not supported for " - f"{attn_backend_i.__name__}. Turn off " - f"CompilationConfig.full_cuda_graph or use a different" + f"{attn_backend_i.__name__}. Set " + f"CompilationConfig.cudagraph_mode to `NONE` " + f"or `PIECEWISE`, or use a different" f" attention backend.") if len(self.compilation_config.splitting_ops) == 0: - assert attn_cg == AttentionCGSupport.ALWAYS, ( + assert attn_cg in [ + AttentionCGSupport.ALWAYS_UNIFIED, + AttentionCGSupport.ALWAYS_SEPARATE, + ], ( f"Full CUDAGraph not supported for " f"{attn_backend_i.__name__} with " f"CompilationConfig.splitting_ops = []. 
" f"Set it to None (default values) " f"or use a different attention backend.") - # check if the attention backends enforce to have separate - # routines for mix prefill-decode and pure decode phase - if attn_metadata_builder_i.prefer_separate_routine is not None \ - and self.compilation_config.separate_attention_routine\ - != attn_metadata_builder_i.prefer_separate_routine: + # check if the attention backends compatible with + # CompilationConfig.separate_attention_routine + is_updated = False + expected = False + if attn_cg == AttentionCGSupport.ALWAYS_UNIFIED and \ + self.compilation_config.separate_attention_routine: + expected = False + is_updated = True + if attn_cg == AttentionCGSupport.PURE_DECODE_ONLY and \ + not self.compilation_config.separate_attention_routine: + expected = True + is_updated = True - expected = attn_metadata_builder_i.prefer_separate_routine + if is_updated: logger.warning_once( f"Full CUDAGraph for {attn_backend_i.__name__}" f"expect CompilationConfig.separate_attention" f"_rountine as: {expected}. Now set it to: " f"{expected}.") - self.compilation_config.separate_attention_routine = \ expected + # when AttentionCGSupport.ALWAYS_SEPARATE, we don't change + # the separate_attention_routine flag, but should inform + # the user that this flag can be turned on to obtain + # better performance. + if attn_cg == AttentionCGSupport.ALWAYS_SEPARATE and \ + not self.compilation_config.separate_attention_routine: + logger.warning_once( + f"Full CUDAGraph for {attn_backend_i.__name__} " + f"supports capturing separate attention routine " + f"for pure decode and mix prefill-decode batches. " + f"You can turn on CompilationConfig.separate_" + f"attention_routine to obtain better performance.") + # for attn_cg is pure decode only, and no compilation, # we skip capturing mix prefill-decode (general) batches. if attn_cg == AttentionCGSupport.PURE_DECODE_ONLY and \ self.no_compilation: - self.skip_capture_general_batches = True + self.capture_mixed_batches = False self.attn_backends.append(attn_backend_i) self.attn_metadata_builders.append(attn_metadata_builder_i) @@ -2606,7 +2526,7 @@ def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None: # Trigger cudagraph initialization here (after # initializing attn backends). # TODO: move this to better place. 
- self._maybe_initialize_cudagraph() + self.cudagraph_dispatcher.maybe_initialize_cudagraph() def may_reinitialize_input_batch(self, kv_cache_config: KVCacheConfig) -> None: From dc455ee2895fa29242056331a8ffce6d69d71093 Mon Sep 17 00:00:00 2001 From: fhl2000 <63384265+fhl2000@users.noreply.github.com> Date: Thu, 10 Jul 2025 01:44:19 +0000 Subject: [PATCH 20/33] cleanup Signed-off-by: fhl2000 <63384265+fhl2000@users.noreply.github.com> --- vllm/compilation/backends.py | 20 ++++---- vllm/compilation/cuda_graph.py | 13 ++--- vllm/compilation/piecewise_backend.py | 8 ++- vllm/compilation/wrapper.py | 3 +- vllm/config.py | 6 +-- vllm/forward_context.py | 15 +++--- vllm/platforms/cuda.py | 6 +-- vllm/platforms/interface.py | 2 +- vllm/v1/attention/backends/flash_attn.py | 8 +-- vllm/v1/attention/backends/flashinfer.py | 4 +- vllm/v1/attention/backends/utils.py | 4 +- vllm/v1/cudagraph_dispatcher.py | 64 ++++++++++++------------ vllm/v1/worker/gpu_model_runner.py | 29 +++++------ 13 files changed, 87 insertions(+), 95 deletions(-) diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 845368326d8..82879e98c1f 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -15,8 +15,8 @@ from torch._dispatch.python import enable_python_dispatcher import vllm.envs as envs -from vllm.config import (CompilationConfig, VllmConfig, CUDAGraphMode, - CUDAGraphRuntimeStyle) +from vllm.config import (CompilationConfig, CUDAGraphMode, + CUDAGraphRuntimeStyle, VllmConfig) from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.utils import is_torch_equal_or_newer, resolve_obj_by_qualname @@ -277,7 +277,6 @@ def split_graph(graph: fx.GraphModule, return split_gm, outputs - compilation_start_time = 0.0 @@ -338,20 +337,20 @@ def call_module(self, target: torch.fx.node.Target, num_graphs=len(self.compile_submod_names), runtime_shape=None) # Lazy import here to avoid circular import - from .piecewise_backend import PiecewiseBackend from .cuda_graph import CUDAGraphOptions - + from .piecewise_backend import PiecewiseBackend + piecewise_backend = PiecewiseBackend( submod, self.vllm_config, index, len(self.compile_submod_names), sym_shape_indices, - compiled_graph_for_general_shape, self.vllm_backend) - + compiled_graph_for_dynamic_shape, self.vllm_backend) + if self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE: # resolve the static graph wrapper class (e.g. CUDAGraphWrapper # class) as platform dependent. 
static_graph_wrapper_class = resolve_obj_by_qualname( current_platform.get_static_graph_wrapper_cls()) - + # Always assign PIECEWISE runtime style to the # CUDAGraphWrapper for piecewise_backend, to distinguish # it from the FULL cudagraph runtime style, no matter it @@ -361,12 +360,11 @@ def call_module(self, target: torch.fx.node.Target, self.vllm_config, CUDAGraphRuntimeStyle.PIECEWISE, self.graph_pool, - cudagraph_options = CUDAGraphOptions( + cudagraph_options=CUDAGraphOptions( debug_log_enable=piecewise_backend.is_first_graph, gc_disable=not piecewise_backend.is_first_graph, weak_ref_output=piecewise_backend.is_last_graph, - usage_str="piecewise" - )) + usage_str="piecewise")) else: self.module.__dict__[target] = piecewise_backend diff --git a/vllm/compilation/cuda_graph.py b/vllm/compilation/cuda_graph.py index 3369eeff3dd..1eef9bdfebc 100644 --- a/vllm/compilation/cuda_graph.py +++ b/vllm/compilation/cuda_graph.py @@ -13,8 +13,8 @@ from vllm.config import CUDAGraphRuntimeStyle, VllmConfig from vllm.forward_context import get_forward_context from vllm.logger import init_logger -from vllm.utils import weak_ref_tensors from vllm.platforms import current_platform +from vllm.utils import weak_ref_tensors logger = init_logger(__name__) @@ -27,7 +27,7 @@ class CUDAGraphEntry: # for cudagraph debugging, track the input addresses # during capture, and check if they are the same during replay - input_addresses: Optional[list[int]] = None + input_addresses: Optional[list[int]] = None @dataclasses.dataclass @@ -35,7 +35,7 @@ class CUDAGraphOptions: debug_log_enable: bool = True gc_disable: bool = False weak_ref_output: bool = True - usage_str: Optional[str] = None # For debug logging only + usage_str: Optional[str] = None # For debug logging only class CUDAGraphWrapper: @@ -48,7 +48,7 @@ def __init__(self, runnable: Callable, vllm_config: VllmConfig, runtime_style: CUDAGraphRuntimeStyle, - graph_pool: Any = current_platform.get_global_graph_pool(), + graph_pool: Any = None, cudagraph_options: Optional[CUDAGraphOptions] = None): self.runnable = runnable self.vllm_config = vllm_config @@ -62,12 +62,13 @@ def __init__(self, # assert runtime_style is not NONE(no cudagraph), otherwise, we don't # need to initialize a CUDAGraphWrapper. 
assert self.runtime_style != CUDAGraphRuntimeStyle.NONE - assert self.graph_pool is not None + if self.graph_pool is None: + self.graph_pool = current_platform.get_default_cudagraph_pool() if cudagraph_options is None: cudagraph_options = CUDAGraphOptions() self.cudagraph_options = cudagraph_options - + self.cudagraph_capture_sizes: set[int] = set( self.compilation_config.cudagraph_capture_sizes) # the entries for different shapes that we need to capture cudagraph diff --git a/vllm/compilation/piecewise_backend.py b/vllm/compilation/piecewise_backend.py index 982118a114c..aad1293e317 100644 --- a/vllm/compilation/piecewise_backend.py +++ b/vllm/compilation/piecewise_backend.py @@ -2,18 +2,16 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import dataclasses -from typing import Any, Callable, Optional +from typing import Any, Callable import torch.fx as fx import vllm.envs as envs from vllm.compilation.backends import VllmBackend from vllm.compilation.monitor import end_monitoring_torch_compile -from vllm.config import CUDAGraphRuntimeStyle, VllmConfig +from vllm.config import VllmConfig from vllm.forward_context import get_forward_context from vllm.logger import init_logger -from vllm.platforms import current_platform -from vllm.utils import resolve_obj_by_qualname logger = init_logger(__name__) @@ -70,7 +68,7 @@ def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig, # to_be_compiled_sizes tracks the remaining sizes to compile, # and updates during the compilation process, so we need to copy it self.to_be_compiled_sizes: set[int] = self.compile_sizes.copy() - + # We only keep compilation management inside this class directly. for shape in self.compile_sizes: self.concrete_size_entries[shape] = ConcreteSizeEntry( diff --git a/vllm/compilation/wrapper.py b/vllm/compilation/wrapper.py index 2f2349474ea..34ce1c4f25c 100644 --- a/vllm/compilation/wrapper.py +++ b/vllm/compilation/wrapper.py @@ -11,7 +11,8 @@ import torch import vllm.envs as envs -from vllm.config import CompilationLevel, get_current_vllm_config, CUDAGraphMode +from vllm.config import (CompilationLevel, CUDAGraphMode, + get_current_vllm_config) from vllm.logger import init_logger logger = init_logger(__name__) diff --git a/vllm/config.py b/vllm/config.py index ab74099b57d..4ab8918c558 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -3922,7 +3922,7 @@ class CompilationLevel: class CUDAGraphMode(enum.Enum): - # constants for the config of the cudagraph mode. + # constants for the config of the cudagraph mode. NONE = 0 PIECEWISE = 1 FULL = 2 @@ -4094,8 +4094,8 @@ class CompilationConfig: # CudaGraph compilation cudagraph_mode: CUDAGraphMode = field( - default_factory=lambda: CUDAGraphMode.PIECEWISE if envs.VLLM_USE_V1 - else CUDAGraphMode.NONE) + default_factory=lambda: CUDAGraphMode.PIECEWISE + if envs.VLLM_USE_V1 else CUDAGraphMode.NONE) """ The mode of the cudagraph. - NONE, no cudagraph capture. 
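The CUDAGraphWrapper introduced above follows the standard capture-then-replay pattern: run the wrapped callable eagerly for a few warmup calls per shape, capture a torch.cuda.CUDAGraph for that shape, then replay it after copying fresh inputs into the captured buffers. Below is a heavily simplified, self-contained sketch of that pattern, assuming a single-tensor input; it omits the shared graph memory pool, the runtime-style check against the forward context, usage-string logging, and the input-address debugging that the real wrapper carries. `SimpleGraphWrapper` is an illustrative name, not a vLLM class.

import torch


class SimpleGraphWrapper:
    """Minimal capture-then-replay sketch (not vLLM's CUDAGraphWrapper)."""

    def __init__(self, runnable, num_warmups: int = 1):
        self.runnable = runnable
        self.num_warmups = num_warmups
        # one entry per batch size: (graph, static_input, static_output)
        self.entries: dict[int, tuple] = {}
        self.warmups_done: dict[int, int] = {}

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        bs = x.shape[0]
        if bs in self.entries:
            graph, static_in, static_out = self.entries[bs]
            # Replay requires stable input addresses, so copy into the
            # buffer that was live during capture.
            static_in.copy_(x)
            graph.replay()
            return static_out

        # Run eagerly for the first few calls of this shape (warmup).
        self.warmups_done[bs] = self.warmups_done.get(bs, 0) + 1
        if self.warmups_done[bs] <= self.num_warmups:
            return self.runnable(x)

        # Capture: kernels launched inside the context are recorded, and
        # the output tensor allocated during capture stays alive as the
        # static output buffer.
        static_in = x.clone()
        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph):
            static_out = self.runnable(static_in)
        self.entries[bs] = (graph, static_in, static_out)
        graph.replay()  # populate static_out with real values for this call
        return static_out


if __name__ == "__main__" and torch.cuda.is_available():
    model = torch.nn.Linear(8, 8).cuda()
    wrapped = SimpleGraphWrapper(model, num_warmups=1)
    x = torch.randn(4, 8, device="cuda")
    for _ in range(3):
        out = wrapped(x)  # eager warmup, capture + replay, then pure replay
    print(out.shape)

Real deployments additionally share one memory pool across captures so that smaller shapes can reuse memory allocated for larger ones, which is why capture_model() captures the large shapes first.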
diff --git a/vllm/forward_context.py b/vllm/forward_context.py index 5850fd2003c..0de5f1fc93a 100644 --- a/vllm/forward_context.py +++ b/vllm/forward_context.py @@ -112,14 +112,13 @@ def get_forward_context() -> ForwardContext: @contextmanager -def set_forward_context( - attn_metadata: Any, - vllm_config: VllmConfig, - virtual_engine: int = 0, - num_tokens: Optional[int] = None, - num_tokens_across_dp: Optional[torch.Tensor] = None, - cudagraph_runtime_style: CUDAGraphRuntimeStyle = CUDAGraphRuntimeStyle.NONE, -): +def set_forward_context(attn_metadata: Any, + vllm_config: VllmConfig, + virtual_engine: int = 0, + num_tokens: Optional[int] = None, + num_tokens_across_dp: Optional[torch.Tensor] = None, + cudagraph_runtime_style: CUDAGraphRuntimeStyle = ( + CUDAGraphRuntimeStyle.NONE)): """A context manager that stores the current forward context, can be attention metadata, etc. Here we can inject common logic for every model forward pass. diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index b3104082f14..9543ba266a3 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -165,21 +165,21 @@ def check_and_update_config(cls, vllm_config: "VllmConfig") -> None: cache_config.block_size = 64 logger.info( "Forcing kv cache block size to 64 for FlashMLA backend.") - + # lazy import to avoid circular import from vllm.config import CUDAGraphMode if (envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput" and parallel_config.data_parallel_size > 1 and vllm_config.compilation_config.cudagraph_mode - != CUDAGraphMode.NONE): + != CUDAGraphMode.NONE): logger.info( "Data Parallel: Forcing enforce eager to be True since DP " "with DeepEP high-throughput kernels are not CUDA Graph " "compatible. The DeepEP low-latency kernels are CUDA Graph " "compatible. Set the all_to_all backend to deepep_low_latency " "to use those kernels instead.") - + vllm_config.compilation_config.cudagraph_mode = CUDAGraphMode.NONE vllm_config.model_config.enforce_eager = True # TODO (varun): Turning this ON gives incorrect results for the diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index bd81280c352..06fb7a02e13 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -521,7 +521,7 @@ def __getattr__(self, key: str): logger.warning("Current platform %s does not have '%s'" \ " attribute.", self.device_type, key) return None - + def get_global_graph_pool(self) -> Any: """ Return the global graph pool for the this platform. diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index 3f9eae960f7..f5e5e98e0a3 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -22,7 +22,7 @@ get_scheduler_metadata, reshape_and_cache_flash) -from vllm.config import VllmConfig, get_layers_from_vllm_config, CUDAGraphMode +from vllm.config import CUDAGraphMode, VllmConfig, get_layers_from_vllm_config from vllm.logger import init_logger from vllm.utils import cdiv from vllm.v1.attention.backends.utils import ( @@ -155,7 +155,7 @@ def _get_sliding_window_configs( class FlashAttentionMetadataBuilder( - AttentionMetadataBuilder[FlashAttentionMetadata]): + AttentionMetadataBuilder[FlashAttentionMetadata]): # FA2 launches separte routines for prefill-decode and pure decode batches, # while FA3 launches a unified varlen fwd kernel for both prefill-decode # and pure decode batches. 
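The support level a backend advertises here (ALWAYS_SEPARATE for FA2, ALWAYS_UNIFIED for FA3, PURE_DECODE_ONLY for FlashInfer and FlashMLA) is what ultimately decides, together with the batch composition, whether a full or a piecewise cudagraph is replayed at runtime. The following is a condensed standalone sketch of that decision, assuming cudagraph_mode is FULL and using illustrative names rather than the dispatcher's exact API.

import enum


class AttentionCGSupport(enum.Enum):
    NEVER = 0
    PURE_DECODE_ONLY = 1
    ALWAYS_UNIFIED = 2
    ALWAYS_SEPARATE = 3


class CUDAGraphRuntimeStyle(enum.Enum):
    NONE = 0
    PIECEWISE = 1
    FULL = 2


def pick_runtime_style(support: AttentionCGSupport,
                       is_pure_decode: bool,
                       piecewise_compiled: bool) -> CUDAGraphRuntimeStyle:
    # Full cudagraphs are only safe when the attention backend supports
    # this batch composition; otherwise fall back to piecewise cudagraphs,
    # or to eager execution if no piecewise compilation is available.
    attn_ok = (support in (AttentionCGSupport.ALWAYS_UNIFIED,
                           AttentionCGSupport.ALWAYS_SEPARATE)
               or (support is AttentionCGSupport.PURE_DECODE_ONLY
                   and is_pure_decode))
    if attn_ok:
        return CUDAGraphRuntimeStyle.FULL
    return (CUDAGraphRuntimeStyle.PIECEWISE
            if piecewise_compiled else CUDAGraphRuntimeStyle.NONE)


if __name__ == "__main__":
    # FlashInfer-style backend: full graphs only for pure decode batches.
    assert pick_runtime_style(
        AttentionCGSupport.PURE_DECODE_ONLY, is_pure_decode=False,
        piecewise_compiled=True) is CUDAGraphRuntimeStyle.PIECEWISE
    assert pick_runtime_style(
        AttentionCGSupport.PURE_DECODE_ONLY, is_pure_decode=True,
        piecewise_compiled=True) is CUDAGraphRuntimeStyle.FULL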
@@ -180,7 +180,7 @@ def __init__(self, runner: "GPUModelRunner", kv_cache_spec: AttentionSpec, self.max_num_splits = 0 # No upper bound on the number of splits. self.aot_schedule = (get_flash_attn_version() == 3) - + self.use_full_cuda_graph = ( compilation_config.cudagraph_mode == CUDAGraphMode.FULL) @@ -348,7 +348,7 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens, # output buffer. self.scheduler_metadata[n:] = 0 scheduler_metadata = self.scheduler_metadata[:n] - + if num_actual_tokens <= self.max_cudagraph_size: # NOTE(woosuk): Setting num_splits > 1 may increase the memory # usage, because the intermediate buffers of size [num_splits, diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index 75aa3996898..e14450e1a40 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -237,8 +237,8 @@ def __init__(self, runner: GPUModelRunner, kv_cache_spec: AttentionSpec, self._decode_wrapper = None # Wrapper for decode (general shape) compilation_config = self.vllm_config.compilation_config - self.enable_cuda_graph = (compilation_config.cudagraph_mode == - CUDAGraphMode.FULL) + self.enable_cuda_graph = ( + compilation_config.cudagraph_mode == CUDAGraphMode.FULL) if self.enable_cuda_graph: # For full cudagraph capture, one `decode_wrapper` for each batch # size is needed for FlashInfer. diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index 3ef030fed98..0d903fd8ce5 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -1,11 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import abc +import enum import functools from abc import abstractmethod from dataclasses import dataclass -import enum -from typing import TYPE_CHECKING, ClassVar, Generic, Optional, TypeVar +from typing import TYPE_CHECKING, ClassVar, Generic, TypeVar import numpy as np import torch diff --git a/vllm/v1/cudagraph_dispatcher.py b/vllm/v1/cudagraph_dispatcher.py index d3348dcc549..8323556555a 100644 --- a/vllm/v1/cudagraph_dispatcher.py +++ b/vllm/v1/cudagraph_dispatcher.py @@ -1,10 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from typing import TYPE_CHECKING, Any -from vllm.config import (CUDAGraphRuntimeStyle, VllmConfig, CompilationLevel, - CUDAGraphMode) + +from vllm.compilation.cuda_graph import CUDAGraphOptions, CUDAGraphWrapper +from vllm.config import (CompilationLevel, CUDAGraphMode, + CUDAGraphRuntimeStyle, VllmConfig) from vllm.v1.attention.backends.utils import AttentionCGSupport -from vllm.compilation.cuda_graph import CUDAGraphWrapper, CUDAGraphOptions if TYPE_CHECKING: from vllm.v1.worker.gpu_model_runner import GPUModelRunner @@ -13,7 +14,6 @@ logger = init_logger(__name__) - # constant for pure decode DECODE_BOOLEN = True @@ -23,8 +23,7 @@ class CudagraphDispatcher: Runtime cudagraph dispatcher for gpu model runner. """ - def __init__(self, runner: "GPUModelRunner", - vllm_config: VllmConfig): + def __init__(self, runner: "GPUModelRunner", vllm_config: VllmConfig): self.runner = runner self.vllm_config = vllm_config self.compilation_config = vllm_config.compilation_config @@ -33,7 +32,7 @@ def __init__(self, runner: "GPUModelRunner", # Dict to store cudagraph candidates for runtime dispatching. 
self.cudagraph_candidates: dict[tuple, Any] = {} - + def after_load_model(self): # add original model to cudagraph_candidates for profile run. assert self.runner.model is not None, "model is not loaded" @@ -62,39 +61,38 @@ def maybe_initialize_cudagraph(self): self.runner.capture_mixed_batches: self.cudagraph_candidates.update({ (CUDAGraphRuntimeStyle.FULL, not DECODE_BOOLEN): - CUDAGraphWrapper( - self.runner.model, - self.vllm_config, - runtime_style=CUDAGraphRuntimeStyle.FULL, - cudagraph_options=CUDAGraphOptions( - usage_str="full/mixed")) + CUDAGraphWrapper(self.runner.model, + self.vllm_config, + runtime_style=CUDAGraphRuntimeStyle.FULL, + cudagraph_options=CUDAGraphOptions( + usage_str="full/mixed")) }) logger.debug("Full cudagraph for mixed batches initialized") # create full cudagraph for pure decode batches. if self.compilation_config.separate_attention_routine: self.cudagraph_candidates.update({ (CUDAGraphRuntimeStyle.FULL, DECODE_BOOLEN): - CUDAGraphWrapper( - self.runner.model, - self.vllm_config, - runtime_style=CUDAGraphRuntimeStyle.FULL, - cudagraph_options=CUDAGraphOptions( - usage_str="full/pure-decode")) + CUDAGraphWrapper(self.runner.model, + self.vllm_config, + runtime_style=CUDAGraphRuntimeStyle.FULL, + cudagraph_options=CUDAGraphOptions( + usage_str="full/pure-decode")) }) logger.debug( "Full cudagraph for pure decode batches initialized") - - def get_cudagraph_runtime_style(self, attn_cuda_graphs: bool) -> CUDAGraphRuntimeStyle: # noqa + + def get_cudagraph_runtime_style( + self, attn_cuda_graphs: bool) -> CUDAGraphRuntimeStyle: # noqa if self.cudagraph_mode == CUDAGraphMode.NONE: return CUDAGraphRuntimeStyle.NONE - + if self.cudagraph_mode == CUDAGraphMode.PIECEWISE: # safe to direct return as we have already checked # CUDAGraphMode.PIECEWISE is compatible only when # enable vllm compilation. return CUDAGraphRuntimeStyle.PIECEWISE - + # Otherwise, for modes that enable full cudagraph. # Some attention backends only support CUDA Graphs in pure decode. @@ -102,7 +100,7 @@ def get_cudagraph_runtime_style(self, attn_cuda_graphs: bool) -> CUDAGraphRuntim # and turn back to the piecewise CUDA graphs. 
cudagraph_runtime_style = CUDAGraphRuntimeStyle.FULL if\ attn_cuda_graphs else CUDAGraphRuntimeStyle.PIECEWISE - + # PIECEWISE would fall back to NONE if no compilation if cudagraph_runtime_style == CUDAGraphRuntimeStyle.PIECEWISE and \ self.no_compilation: @@ -126,9 +124,10 @@ def dispatch(self, cudagraph_runtime_style: CUDAGraphRuntimeStyle, ]: selected_model = self.cudagraph_candidates.get( (cudagraph_runtime_style, ), None) - assert selected_model is not None, ("cudagraph_candidates is not" - " correctly initialized for key: " - f"({cudagraph_runtime_style}, ).") + assert selected_model is not None, ( + "cudagraph_candidates is not" + " correctly initialized for key: " + f"({cudagraph_runtime_style}, ).") else: # for full cudagraph, select between general batches # or pure decode batches @@ -137,8 +136,9 @@ def dispatch(self, cudagraph_runtime_style: CUDAGraphRuntimeStyle, else (not DECODE_BOOLEN,) tuple_key = (cudagraph_runtime_style, ) + decode_case selected_model = self.cudagraph_candidates.get(tuple_key, None) - assert selected_model is not None, ("cudagraph_candidates is not" - " correctly initialized for key: " - f"({cudagraph_runtime_style}, " - f"{is_pure_decode}).") - return selected_model \ No newline at end of file + assert selected_model is not None, ( + "cudagraph_candidates is not" + " correctly initialized for key: " + f"({cudagraph_runtime_style}, " + f"{is_pure_decode}).") + return selected_model diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index bfb727a0bbb..ab5a0d310d1 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -6,7 +6,7 @@ import time import weakref from contextlib import contextmanager -from typing import TYPE_CHECKING, Any, Literal, Optional, Union +from typing import TYPE_CHECKING, Any, Optional, Union import numpy as np import torch @@ -19,7 +19,6 @@ from vllm.attention.backends.abstract import AttentionBackend from vllm.attention.layer import Attention from vllm.compilation.counter import compilation_counter -from vllm.compilation.cuda_graph import CUDAGraphWrapper from vllm.config import (CompilationLevel, CUDAGraphMode, CUDAGraphRuntimeStyle, VllmConfig, get_layers_from_vllm_config) @@ -53,6 +52,7 @@ AttentionMetadataBuilder, CommonAttentionMetadata) from vllm.v1.core.encoder_cache_manager import compute_encoder_budget +from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher from vllm.v1.kv_cache_interface import (AttentionSpec, FullAttentionSpec, KVCacheConfig, KVCacheSpec, MambaSpec, SlidingWindowSpec) @@ -70,7 +70,6 @@ from vllm.v1.worker.block_table import BlockTable from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin -from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher from ..sample.logits_processor import LogitsProcessorManager from .utils import (gather_mm_placeholders, initialize_kv_cache_for_kv_sharing, @@ -91,7 +90,6 @@ logger = init_logger(__name__) - class GPUModelRunner(LoRAModelRunnerMixin): def __init__( @@ -327,7 +325,7 @@ def __init__( self.capture_mixed_batches = True self.no_compilation = self.compilation_config.level != \ CompilationLevel.PIECEWISE or self.model_config.enforce_eager - + # Cudagraph dispatcher for runtime cudagraph dispatching. 
self.cudagraph_dispatcher = CudagraphDispatcher(self, self.vllm_config) @@ -1387,9 +1385,8 @@ def execute_model( cudagraph_runtime_style=cudagraph_runtime_style): self.maybe_setup_kv_connector(scheduler_output) - model = self.cudagraph_dispatcher.dispatch( - cudagraph_runtime_style, - is_pure_decode) + model = self.cudagraph_dispatcher.dispatch(cudagraph_runtime_style, + is_pure_decode) model_output = model( input_ids=input_ids, @@ -1993,7 +1990,8 @@ def rand_input_ids() -> torch.Tensor: def _dummy_run( self, num_tokens: int, - cudagraph_runtime_style: CUDAGraphRuntimeStyle = CUDAGraphRuntimeStyle.NONE, # noqa + cudagraph_runtime_style: CUDAGraphRuntimeStyle = ( + CUDAGraphRuntimeStyle.NONE), is_pure_decode: bool = False, skip_eplb: bool = False, is_profile: bool = False, @@ -2026,7 +2024,6 @@ def _dummy_run( max_query_len = 1 if is_pure_decode else num_tokens attn_metadata: Optional[dict[str, Any]] = None - if cudagraph_runtime_style == CUDAGraphRuntimeStyle.FULL: attn_metadata = {} @@ -2046,7 +2043,7 @@ def _dummy_run( num_actual_tokens=num_tokens, max_query_len=max_query_len, ) - + for kv_cache_group_id, kv_cache_group_spec in enumerate( self.kv_cache_config.kv_cache_groups): @@ -2373,10 +2370,9 @@ def capture_model(self) -> None: for _ in range( self.compilation_config.cudagraph_num_of_warmups): # use CUDAGraphRuntimeStyle.NONE (default) for warmup - self._dummy_run( - num_tokens, - is_pure_decode=False, - skip_eplb=True) + self._dummy_run(num_tokens, + is_pure_decode=False, + skip_eplb=True) self._dummy_run( num_tokens, cudagraph_runtime_style=cudagraph_runtime_style, @@ -2477,8 +2473,7 @@ def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None: assert attn_cg in [ AttentionCGSupport.ALWAYS_UNIFIED, AttentionCGSupport.ALWAYS_SEPARATE, - ], ( - f"Full CUDAGraph not supported for " + ], (f"Full CUDAGraph not supported for " f"{attn_backend_i.__name__} with " f"CompilationConfig.splitting_ops = []. " f"Set it to None (default values) " From 620a728d49a703114427a4985d3bd64517d9951c Mon Sep 17 00:00:00 2001 From: fhl <2410591650@qq.com> Date: Thu, 10 Jul 2025 10:52:27 +0800 Subject: [PATCH 21/33] fix warmup Signed-off-by: fhl <2410591650@qq.com> --- vllm/compilation/cuda_graph.py | 2 +- vllm/v1/worker/gpu_model_runner.py | 16 +++++++++++++--- 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/vllm/compilation/cuda_graph.py b/vllm/compilation/cuda_graph.py index 1eef9bdfebc..c316ebcf6b1 100644 --- a/vllm/compilation/cuda_graph.py +++ b/vllm/compilation/cuda_graph.py @@ -63,7 +63,7 @@ def __init__(self, # need to initialize a CUDAGraphWrapper. 
assert self.runtime_style != CUDAGraphRuntimeStyle.NONE if self.graph_pool is None: - self.graph_pool = current_platform.get_default_cudagraph_pool() + self.graph_pool = current_platform.get_global_graph_pool() if cudagraph_options is None: cudagraph_options = CUDAGraphOptions() diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index ab5a0d310d1..5807584300a 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1992,6 +1992,7 @@ def _dummy_run( num_tokens: int, cudagraph_runtime_style: CUDAGraphRuntimeStyle = ( CUDAGraphRuntimeStyle.NONE), + force_attention: bool = False, is_pure_decode: bool = False, skip_eplb: bool = False, is_profile: bool = False, @@ -2025,7 +2026,10 @@ def _dummy_run( attn_metadata: Optional[dict[str, Any]] = None - if cudagraph_runtime_style == CUDAGraphRuntimeStyle.FULL: + # If force_attention is True, we always capture attention. Otherwise, + # it depends on the cudagraph_runtime_style to be FULL or PIECEWISE. + if force_attention or cudagraph_runtime_style == \ + CUDAGraphRuntimeStyle.FULL: attn_metadata = {} query_start_loc = self.query_start_loc[:num_reqs + 1] @@ -2369,9 +2373,14 @@ def capture_model(self) -> None: for num_tokens in compilation_cases: for _ in range( self.compilation_config.cudagraph_num_of_warmups): - # use CUDAGraphRuntimeStyle.NONE (default) for warmup + # Use CUDAGraphRuntimeStyle.NONE (default) for warmup. + # But be careful, warm up with `NONE`is orthogonal to + # if we want to warm up attention or not. This is + # different from the case where `FULL` implies capture + # attention while `PIECEWISE` implies no attention. self._dummy_run(num_tokens, is_pure_decode=False, + force_attention=attn_cuda_graphs, skip_eplb=True) self._dummy_run( num_tokens, @@ -2397,9 +2406,10 @@ def capture_model(self) -> None: for num_tokens in compilation_cases_decode: for _ in range( self.compilation_config.cudagraph_num_of_warmups): - # use CUDAGraphRuntimeStyle.NONE (default) for warmup + # Always force attention for warmup of pure decode. self._dummy_run(num_tokens, is_pure_decode=True, + force_attention=True, skip_eplb=True) self._dummy_run( num_tokens, From b1e6978d088b7dafb7cae154f7d670822b813f7e Mon Sep 17 00:00:00 2001 From: fhl2000 <63384265+fhl2000@users.noreply.github.com> Date: Thu, 10 Jul 2025 11:44:24 +0800 Subject: [PATCH 22/33] Commit suggestion: Update vllm/config.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Luka Govedič Signed-off-by: fhl2000 <63384265+fhl2000@users.noreply.github.com> --- vllm/config.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 4ab8918c558..a235f1c5944 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -4116,10 +4116,9 @@ class CompilationConfig: Currently, the cudagraph mode is only used for the v1 engine. Note that the cudagraph logic is generally orthogonal to the - compilation logic. For piecewise cudagraph, the logic is kept - inside the compilation. Meanwhile, the full cudagraph is captured - outside the compilation, and it further supports cudagraph - without compilation. + compilation logic. While piecewise cudagraphs require piecewise + compilation (level=PIECEWISE and non-empty splitting_ops), full + cudagraphs are supported with and without compilation. """ cudagraph_num_of_warmups: int = 0 """Number of warmup runs for cudagraph. 
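Note on the interaction described above: the per-batch selection between eager, piecewise, and full cudagraph execution that this series implements in CudagraphDispatcher.get_cudagraph_runtime_style can be summarized by the following minimal sketch. It is an illustration only, not part of the patch; the helper name resolve_runtime_style and its simplified boolean arguments are assumptions, while the enum members mirror the CUDAGraphMode and CUDAGraphRuntimeStyle values used throughout the series.

    from enum import Enum

    class CUDAGraphMode(Enum):
        NONE = 0
        PIECEWISE = 1
        FULL = 2

    class CUDAGraphRuntimeStyle(Enum):
        NONE = 0
        PIECEWISE = 1
        FULL = 2

    def resolve_runtime_style(mode: CUDAGraphMode,
                              piecewise_compilation: bool,
                              attn_supports_full_cg: bool) -> CUDAGraphRuntimeStyle:
        # No cudagraphs requested (e.g. enforce_eager): always run eagerly.
        if mode == CUDAGraphMode.NONE:
            return CUDAGraphRuntimeStyle.NONE
        # Piecewise cudagraphs live inside the compiled piecewise graph, so
        # this mode presupposes level=PIECEWISE with non-empty splitting_ops.
        if mode == CUDAGraphMode.PIECEWISE:
            assert piecewise_compilation
            return CUDAGraphRuntimeStyle.PIECEWISE
        # FULL mode: capture the whole forward pass when the attention
        # backend supports cudagraphs for this batch; otherwise fall back to
        # piecewise cudagraphs (the later refactor additionally requires the
        # attention ops to be in splitting_ops) or to eager execution.
        if attn_supports_full_cg:
            return CUDAGraphRuntimeStyle.FULL
        return (CUDAGraphRuntimeStyle.PIECEWISE
                if piecewise_compilation else CUDAGraphRuntimeStyle.NONE)
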
From beee69a6d4bdd4a14cfbaa61d0cafe6372c498b4 Mon Sep 17 00:00:00 2001 From: fhl2000 <63384265+fhl2000@users.noreply.github.com> Date: Thu, 10 Jul 2025 11:45:53 +0800 Subject: [PATCH 23/33] commit suggestion2: Update vllm/config.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Luka Govedič Signed-off-by: fhl2000 <63384265+fhl2000@users.noreply.github.com> --- vllm/config.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index a235f1c5944..087493be36c 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -4137,11 +4137,11 @@ class CompilationConfig: internally managed buffer. Default is False.""" separate_attention_routine: bool = False """ - Enable a distinct attention calls routine under an attention backend for + Enable distinct attention routines for mixed and pure-decode batches during full cuda graph capturing. This is because some attention backends like - FlashMLA, FlashInfer, FA2, etc. implement different branches for mix - prefill-decode and pure decode cases. This flag enables us to potentially - capture the cudagraph separately for each branch. + FlashMLA, FlashInfer, FA2, etc. implement different branches for mixed + prefill-decode and pure decode cases. This flag enables capturing separate + cudagraphs for each branch. """ pass_config: PassConfig = field(default_factory=PassConfig) """Custom inductor passes, see PassConfig for more details""" From 21b1a8dcdff0f3f16462390f5a421ee7f5a26623 Mon Sep 17 00:00:00 2001 From: fhl2000 <63384265+fhl2000@users.noreply.github.com> Date: Thu, 10 Jul 2025 06:15:19 +0000 Subject: [PATCH 24/33] fix enforce_eager Signed-off-by: fhl2000 <63384265+fhl2000@users.noreply.github.com> --- vllm/config.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/vllm/config.py b/vllm/config.py index 087493be36c..a361467ff77 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -4651,6 +4651,10 @@ def __post_init__(self): if self.compilation_config.level is None: self.compilation_config.level = CompilationLevel.NO_COMPILATION + # disable cudagraph if enforce eager execution + if self.model_config is not None and self.model_config.enforce_eager: + self.compilation_config.cudagraph_mode = CUDAGraphMode.NONE + self._set_cudagraph_sizes() if self.cache_config.cpu_offload_gb > 0 and \ From 210359af600211649535bd88dd75970e5d1919c0 Mon Sep 17 00:00:00 2001 From: fhl2000 <63384265+fhl2000@users.noreply.github.com> Date: Thu, 10 Jul 2025 07:44:59 +0000 Subject: [PATCH 25/33] small cleanup for pre-commit Signed-off-by: fhl2000 <63384265+fhl2000@users.noreply.github.com> --- vllm/platforms/cuda.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index 787a4ec4fbd..ceaeeb61fb3 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -168,7 +168,7 @@ def check_and_update_config(cls, vllm_config: "VllmConfig") -> None: # lazy import to avoid circular import from vllm.config import CUDAGraphMode - + if (envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput" and parallel_config.data_parallel_size > 1 and vllm_config.compilation_config.cudagraph_mode From 699aff307df3d99b6fae2cc517e2bc1340ff11cf Mon Sep 17 00:00:00 2001 From: fhl2000 <63384265+fhl2000@users.noreply.github.com> Date: Sun, 13 Jul 2025 09:35:39 +0000 Subject: [PATCH 26/33] refactors Signed-off-by: fhl2000 <63384265+fhl2000@users.noreply.github.com> --- vllm/compilation/backends.py | 8 +- 
vllm/compilation/base_static_graph.py | 17 +-- vllm/compilation/cuda_graph.py | 21 ++- vllm/config.py | 1 + vllm/platforms/cuda.py | 11 +- vllm/platforms/interface.py | 7 +- vllm/v1/attention/backends/flashinfer.py | 6 +- vllm/v1/attention/backends/mla/common.py | 2 +- vllm/v1/attention/backends/utils.py | 19 +-- vllm/v1/cudagraph_dispatcher.py | 56 ++++---- vllm/v1/worker/cpu_model_runner.py | 1 + vllm/v1/worker/gpu_model_runner.py | 164 +++++++++++++---------- 12 files changed, 167 insertions(+), 146 deletions(-) diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 82879e98c1f..fd82951dcf3 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -356,10 +356,10 @@ def call_module(self, target: torch.fx.node.Target, # it from the FULL cudagraph runtime style, no matter it # is wrapped on a full or piecewise fx graph. self.module.__dict__[target] = static_graph_wrapper_class( - piecewise_backend, - self.vllm_config, - CUDAGraphRuntimeStyle.PIECEWISE, - self.graph_pool, + runnable=piecewise_backend, + vllm_config=self.vllm_config, + runtime_style=CUDAGraphRuntimeStyle.PIECEWISE, + graph_pool=self.graph_pool, cudagraph_options=CUDAGraphOptions( debug_log_enable=piecewise_backend.is_first_graph, gc_disable=not piecewise_backend.is_first_graph, diff --git a/vllm/compilation/base_static_graph.py b/vllm/compilation/base_static_graph.py index 9e1e9477051..ae98603fc6c 100644 --- a/vllm/compilation/base_static_graph.py +++ b/vllm/compilation/base_static_graph.py @@ -1,10 +1,9 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import enum from typing import Any, Callable, Protocol -from vllm.config import VllmConfig +from vllm.config import CUDAGraphRuntimeStyle, VllmConfig class AbstractStaticGraphWrapper(Protocol): @@ -14,7 +13,8 @@ class AbstractStaticGraphWrapper(Protocol): """ def __init__(self, runnable: Callable, vllm_config: VllmConfig, - graph_pool: Any, runtime_style: enum.Enum, **kwargs): + runtime_style: CUDAGraphRuntimeStyle, graph_pool: Any, + **kwargs): """ Initializes the StaticGraphWrapper class with graph capturing and execution-related configurations. @@ -22,11 +22,11 @@ def __init__(self, runnable: Callable, vllm_config: VllmConfig, Args: runnable (Callable): The callable to be wrapped and captured. vllm_config (VllmConfig): Global configuration for vLLM. + runtime_style (CUDAGraphRuntimeStyle): The style of the static + graph runtime. See CUDAGraphRuntimeStyle in vllm/config.py. graph_pool (Any): Graph memory pool handle, e.g., `torch.cuda.graph_pool_handle()`. - runtime_style (enum.Enum): The style of the static - graph runtime. e.g. see CUDAGraphRuntimeStyle in vllm/config.py. Keyword Args: kwargs: Additional keyword arguments for platform-specific configurations. @@ -37,9 +37,10 @@ def __call__(self, *args, **kwargs) -> Any: """ Executes the wrapped callable. - This may involve replaying a captured static graph if the conditions - are met, or running the original callable eagerly and potentially - capturing it. + If the current CUDAGraphRuntimeStyle in the ForwardContext + matches the runtime style of this instance, it replays the CUDAGraph + or captures it using the callable if it hasn't been captured yet. + Otherwise, it calls the original callable directly. 
Args: *args: Variable length input arguments to be passed into the diff --git a/vllm/compilation/cuda_graph.py b/vllm/compilation/cuda_graph.py index c316ebcf6b1..313eebeb24a 100644 --- a/vllm/compilation/cuda_graph.py +++ b/vllm/compilation/cuda_graph.py @@ -85,17 +85,14 @@ def __call__(self, *args, **kwargs): if cudagraph_runtime_style == CUDAGraphRuntimeStyle.NONE or\ runtime_shape is None: - # make sure it's on profile run, eager run, or warmup stage. + # This could mean the profile run, a warmup run, or running + # without cudagraphs. return self.runnable(*args, **kwargs) if cudagraph_runtime_style != self.runtime_style: # Only triggers capture/replay if the runtime style matches, - # otherwise, we fallback to the original runnable to handle - # no match case. This is a hack to avoid double capturing - # cudagraph and ensure extra safety in situations where we - # have nested CUDAdGraphWrapper structure, e.g., we have - # piecewise cudagraph for piecewise backend, which may be - # further wrapped to obtain a full cudagraph. See #20059 for - # more details. + # otherwise, we fallback to the original runnable. + # This enables properly dispatching to the correct CUDAGraphWrapper + # when nesting multiple instances with different runtime styles. return self.runnable(*args, **kwargs) if runtime_shape not in self.concrete_cudagraph_entries: @@ -108,7 +105,8 @@ def __call__(self, *args, **kwargs): if self.cudagraph_options.debug_log_enable: # Since we capture cudagraph for many different shapes and # capturing is fast, we don't need to log it for every - # shape. We only log it in the debug mode. + # shape. E.g. we only log it for the first subgraph in + # piecewise mode. logger.debug("Capturing a cudagraph of %s usage for shape %s", self.cudagraph_options.usage_str, entry.runtime_shape) @@ -139,8 +137,9 @@ def __call__(self, *args, **kwargs): # by converting it to weak ref, # the original `output` will immediately be released # to save memory. It is only safe to do this for - # the last graph, because the output of the last - # graph will not be used by any other cuda graph. + # the last graph in piecewise cuadgraph mode, because + # the output of the last graph will not be used by + # any other cuda graph. output = weak_ref_tensors(output) # here we always use weak ref for the output diff --git a/vllm/config.py b/vllm/config.py index 32f7a4b6f03..86629a8767f 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -4671,6 +4671,7 @@ def __post_init__(self): # disable cudagraph if enforce eager execution if self.model_config is not None and self.model_config.enforce_eager: + logger.info("Cudagraph is disabled under eager mode.") self.compilation_config.cudagraph_mode = CUDAGraphMode.NONE self._set_cudagraph_sizes() diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index eb5bb5db86c..cca0eeb4d74 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -169,20 +169,19 @@ def check_and_update_config(cls, vllm_config: "VllmConfig") -> None: # lazy import to avoid circular import from vllm.config import CUDAGraphMode + compilation_config = vllm_config.compilation_config if (envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput" and parallel_config.data_parallel_size > 1 - and vllm_config.compilation_config.cudagraph_mode - != CUDAGraphMode.NONE): + and compilation_config.cudagraph_mode != CUDAGraphMode.NONE): logger.info( "Data Parallel: Forcing enforce eager to be True since DP " "with DeepEP high-throughput kernels are not CUDA Graph " "compatible. 
The DeepEP low-latency kernels are CUDA Graph " "compatible. Set the all_to_all backend to deepep_low_latency " "to use those kernels instead.") - - vllm_config.compilation_config.cudagraph_mode = CUDAGraphMode.NONE - if vllm_config.model_config is not None: - vllm_config.model_config.enforce_eager = True + compilation_config.cudagraph_mode = CUDAGraphMode.NONE + if model_config is not None: + model_config.enforce_eager = True # TODO (varun): Turning this ON gives incorrect results for the # Deepseek-V2-lite model. diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index 06fb7a02e13..e0a071f97d1 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -526,10 +526,9 @@ def get_global_graph_pool(self) -> Any: """ Return the global graph pool for the this platform. """ - global _global_graph_pool - if _global_graph_pool is None: - _global_graph_pool = self.graph_pool_handle() - return _global_graph_pool + if not hasattr(self, '_global_graph_pool'): + self._global_graph_pool = self.graph_pool_handle() + return self._global_graph_pool @classmethod def get_cu_count(cls, device_id: int = 0) -> int: diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index a63b5ca94da..d242d694a7a 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -472,8 +472,8 @@ def _plan(self, attn_metadata: FlashInferMetadata): self._num_decodes, attn_metadata.max_seq_len, attn_metadata.kv_data_type, attn_metadata.num_qo_heads, attn_metadata.num_kv_heads, attn_metadata.head_dim): - # TODO: Override flashinfer's plan function to avoid some - # host-to-device copy overhead. + # TODO: Override flashinfer's plan function to avoid some + # host-to-device copy overhead. attn_metadata.decode_wrapper.plan( # NOTE: Use the persistent buffer with padding length, # instead of the same address but chunked length buffers @@ -629,7 +629,7 @@ def build_for_cudagraph_capture( "FlashInfer only supports decode-only full CUDAGraph capture. " \ "Make sure all cudagraph capture sizes <= max_num_seq." - m.max_query_len = 1 # decode-only + assert m.max_query_len == 1 # decode-only # Update state usually set in reorder_batch. self._num_decodes = m.num_reqs diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index 970de229e13..eb23fac7f11 100644 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -595,7 +595,7 @@ def build_for_cudagraph_capture( "MLA only supports decode-only full CUDAGraph capture. " \ "Make sure all cudagraph capture sizes <= max_num_seq." - m.max_query_len = 1 # decode-only + assert m.max_query_len == 1 # decode-only # Update state usually set in reorder_batch. self._num_decodes = m.num_reqs diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index 50dff0f2c1e..531883d42af 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -53,19 +53,20 @@ class CommonAttentionMetadata: class AttentionCGSupport(enum.Enum): - # Constants for the cudagraph support of the attention backend - # Here we do not consider the cascade attention, as currently - # it is never cudagraph supported. 
+ """ Constants for the cudagraph support of the attention backend + Here we do not consider the cascade attention, as currently + it is never cudagraph supported.""" - NEVER = 0 # No support + NEVER = 0 + """NO cudagraph support""" PURE_DECODE_ONLY = 1 - # Cudagraph supported for pure decode, need to use piecewise - # cudagraph or no cudagraph for mixed prefill-decode batches + """Cudagraph supported for pure decode, need to use piecewise + cudagraph or no cudagraph for mixed prefill-decode batches""" ALWAYS_UNIFIED = 2 - # Cudagraph always supported with unified routine + """Cudagraph always supported with unified routine""" ALWAYS_SEPARATE = 3 - # Cudagraph supported for both mixed prefill-decode - # or pure decode attention routines. + """ Cudagraph supported for both mixed prefill-decode + or pure decode attention routines.""" class AttentionMetadataBuilder(abc.ABC, Generic[M]): diff --git a/vllm/v1/cudagraph_dispatcher.py b/vllm/v1/cudagraph_dispatcher.py index 8323556555a..a61f35bd5cd 100644 --- a/vllm/v1/cudagraph_dispatcher.py +++ b/vllm/v1/cudagraph_dispatcher.py @@ -1,15 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import TYPE_CHECKING, Any +from typing import Any, Callable from vllm.compilation.cuda_graph import CUDAGraphOptions, CUDAGraphWrapper from vllm.config import (CompilationLevel, CUDAGraphMode, CUDAGraphRuntimeStyle, VllmConfig) -from vllm.v1.attention.backends.utils import AttentionCGSupport - -if TYPE_CHECKING: - from vllm.v1.worker.gpu_model_runner import GPUModelRunner - from vllm.logger import init_logger logger = init_logger(__name__) @@ -20,59 +15,64 @@ class CudagraphDispatcher: """ - Runtime cudagraph dispatcher for gpu model runner. + Runtime cudagraph dispatcher to switch between multiple cudagraphs. """ - def __init__(self, runner: "GPUModelRunner", vllm_config: VllmConfig): - self.runner = runner + def __init__(self, vllm_config: VllmConfig): self.vllm_config = vllm_config self.compilation_config = vllm_config.compilation_config self.cudagraph_mode = self.compilation_config.cudagraph_mode - self.no_compilation = self.runner.no_compilation + self.no_compilation = self.compilation_config.level != \ + CompilationLevel.PIECEWISE or \ + vllm_config.model_config.enforce_eager + + self.model: Callable = None # type: ignore + # we lazy initialize self.model after model loading of model + # runner have been done. # Dict to store cudagraph candidates for runtime dispatching. self.cudagraph_candidates: dict[tuple, Any] = {} - def after_load_model(self): + def after_load_model(self, model: Callable): # add original model to cudagraph_candidates for profile run. - assert self.runner.model is not None, "model is not loaded" + assert model is not None, "model should not be None" + self.model = model self.cudagraph_candidates.update({ (CUDAGraphRuntimeStyle.NONE, ): - self.runner.model + self.model }) - def maybe_initialize_cudagraph(self): - # This is called only after attention backend is initialized. + def maybe_initialize_cudagraph(self, create_mixed_batch_full_cg: bool): + assert self.model is not None, ( + "No model have been assigned to cudagraph dispatcher") + # This should be called only after attention backend is initialized. 
if self.compilation_config.level == CompilationLevel.PIECEWISE\ and len(self.compilation_config.splitting_ops)>0: self.cudagraph_candidates.update({ (CUDAGraphRuntimeStyle.PIECEWISE, ): - self.runner.model + self.model }) logger.debug("Piecewise cudagraph initialized") - if self.runner.full_cuda_graph: - attn_cg = self.runner.attn_metadata_builders[0].\ - attn_cudagraph_support + if self.compilation_config.cudagraph_mode == CUDAGraphMode.FULL: # create full cudagraph for mix prefill-decode/general batches - if attn_cg in [AttentionCGSupport.ALWAYS_UNIFIED, - AttentionCGSupport.ALWAYS_SEPARATE] and \ - self.runner.capture_mixed_batches: + if create_mixed_batch_full_cg: self.cudagraph_candidates.update({ (CUDAGraphRuntimeStyle.FULL, not DECODE_BOOLEN): - CUDAGraphWrapper(self.runner.model, + CUDAGraphWrapper(self.model, self.vllm_config, runtime_style=CUDAGraphRuntimeStyle.FULL, cudagraph_options=CUDAGraphOptions( usage_str="full/mixed")) }) logger.debug("Full cudagraph for mixed batches initialized") - # create full cudagraph for pure decode batches. + # always create full cudagraph for pure decode batches if speparate + # attention routine. if self.compilation_config.separate_attention_routine: self.cudagraph_candidates.update({ (CUDAGraphRuntimeStyle.FULL, DECODE_BOOLEN): - CUDAGraphWrapper(self.runner.model, + CUDAGraphWrapper(self.model, self.vllm_config, runtime_style=CUDAGraphRuntimeStyle.FULL, cudagraph_options=CUDAGraphOptions( @@ -111,12 +111,14 @@ def get_cudagraph_runtime_style( def dispatch(self, cudagraph_runtime_style: CUDAGraphRuntimeStyle, is_pure_decode: bool) -> Any: - # if no cudagraph candidates inside other platforms, + assert self.model is not None, ("No model have been assigned" + "to cudagraph dispatcher") + # if no cudagraph candidates, # just skip cudagraph dispatching. if not self.cudagraph_candidates: logger.warning_once("cudagraphs are not initialized." " No cudagraph will be used.") - return self.runner.model + return self.model # select between no cudagraph and piecewise cudagraph if cudagraph_runtime_style in [ diff --git a/vllm/v1/worker/cpu_model_runner.py b/vllm/v1/worker/cpu_model_runner.py index 410a54e7466..f96d1f64892 100644 --- a/vllm/v1/worker/cpu_model_runner.py +++ b/vllm/v1/worker/cpu_model_runner.py @@ -61,6 +61,7 @@ def load_model(self) -> None: self.model = self.load_lora_model(self.model, self.model_config, self.scheduler_config, self.lora_config, self.device) + self.cudagraph_dispatcher.after_load_model(self.model) def warming_up_model(self) -> None: logger.info("Warming up model for the compilation...") diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 68ce58a90f7..a66655516a4 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -218,8 +218,6 @@ def __init__( ) self.cudagraph_mode = self.compilation_config.cudagraph_mode - self.use_cuda_graph = (self.cudagraph_mode != CUDAGraphMode.NONE - and not self.model_config.enforce_eager) # TODO(woosuk): Provide an option to tune the max cudagraph batch size. # The convention is different. @@ -321,12 +319,12 @@ def __init__( # We may disable capturing cudagraph for mixed batches when # no support (e.g., no piecewise compilation) or want only capturing # full cudagraph for pure decode batches. 
- self.capture_mixed_batches = True + self.capture_mixed_batches = self.cudagraph_mode != CUDAGraphMode.NONE self.no_compilation = self.compilation_config.level != \ CompilationLevel.PIECEWISE or self.model_config.enforce_eager # Cudagraph dispatcher for runtime cudagraph dispatching. - self.cudagraph_dispatcher = CudagraphDispatcher(self, self.vllm_config) + self.cudagraph_dispatcher = CudagraphDispatcher(self.vllm_config) def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None: """ @@ -1301,9 +1299,9 @@ def execute_model( spec_decode_metadata, num_scheduled_tokens_np) = (self._prepare_inputs(scheduler_output)) num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens - if (self.use_cuda_graph + if (self.cudagraph_mode != CUDAGraphMode.NONE and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]): - # Use piecewise CUDA graphs. + # Use CUDA graphs. # Add padding to the batch size. num_input_tokens = self.vllm_config.pad_for_cudagraph( num_scheduled_tokens) @@ -1792,7 +1790,7 @@ def load_model(self) -> None: self.parallel_config, ) # immediately initialize the dispatcher for profile run - self.cudagraph_dispatcher.after_load_model() + self.cudagraph_dispatcher.after_load_model(self.model) def save_tensorized_model( self, @@ -2286,11 +2284,10 @@ def profile_run(self) -> None: gc.collect() def capture_model(self) -> None: - if not self.use_cuda_graph: + if self.cudagraph_mode == CUDAGraphMode.NONE: logger.warning( "Skipping CUDA graph capture. To turn on CUDA graph capture, " - "ensure `cudagraph_mode` was not manually set to %s", - CUDAGraphMode.NONE) + "ensure `cudagraph_mode` was not manually set to `NONE`") return compilation_counter.num_gpu_runner_capture_triggers += 1 @@ -2322,34 +2319,12 @@ def capture_model(self) -> None: and len(self.cudagraph_batch_sizes) > 0 \ and self.cudagraph_batch_sizes[0] == 1: start_idx = 1 - - # We skip EPLB here since we don't want to record dummy metrics - - # Only rank 0 should print progress bar during capture - compilation_cases = reversed( - self.cudagraph_batch_sizes[start_idx:]) - if is_global_first_rank(): - compilation_cases = tqdm( - list(compilation_cases), - disable=not self.load_config.use_tqdm_on_load, - desc="Capturing CUDA graphs (mix prefill-decode)") - for num_tokens in compilation_cases: - for _ in range( - self.compilation_config.cudagraph_num_of_warmups): - # Use CUDAGraphRuntimeStyle.NONE (default) for warmup. - # But be careful, warm up with `NONE`is orthogonal to - # if we want to warm up attention or not. This is - # different from the case where `FULL` implies capture - # attention while `PIECEWISE` implies no attention. - self._dummy_run(num_tokens, - is_pure_decode=False, - force_attention=attn_cuda_graphs, - skip_eplb=True) - self._dummy_run( - num_tokens, - cudagraph_runtime_style=cudagraph_runtime_style, - is_pure_decode=False, - skip_eplb=True) + compilation_cases = list( + reversed(self.cudagraph_batch_sizes[start_idx:])) + self._capture_cudagraphs( + compilation_cases, + cudagraph_runtime_style=cudagraph_runtime_style, + is_pure_decode=False) if self.compilation_config.separate_attention_routine: # Capture full cudagraph for pure decode. 
@@ -2359,27 +2334,12 @@ def capture_model(self) -> None: decode_cudagraph_batch_sizes = [ x for x in self.cudagraph_batch_sizes if x <= max_num_reqs ] - compilation_cases_decode = reversed( - decode_cudagraph_batch_sizes) - if is_global_first_rank(): - compilation_cases_decode = tqdm( - list(compilation_cases_decode), - disable=not self.load_config.use_tqdm_on_load, - desc="Capturing CUDA graphs (pure decode)") - - for num_tokens in compilation_cases_decode: - for _ in range( - self.compilation_config.cudagraph_num_of_warmups): - # Always force attention for warmup of pure decode. - self._dummy_run(num_tokens, - is_pure_decode=True, - force_attention=True, - skip_eplb=True) - self._dummy_run( - num_tokens, - cudagraph_runtime_style=cudagraph_runtime_style, - is_pure_decode=True, - skip_eplb=True) + compilation_cases_decode = list( + reversed(decode_cudagraph_batch_sizes)) + self._capture_cudagraphs( + compilation_cases=compilation_cases_decode, + cudagraph_runtime_style=cudagraph_runtime_style, + is_pure_decode=True) end_time = time.perf_counter() end_free_gpu_memory = torch.cuda.mem_get_info()[0] @@ -2389,6 +2349,35 @@ def capture_model(self) -> None: logger.info("Graph capturing finished in %.0f secs, took %.2f GiB", elapsed_time, cuda_graph_size / (1 << 30)) + def _capture_cudagraphs(self, compilation_cases: list[int], + cudagraph_runtime_style: CUDAGraphRuntimeStyle, + is_pure_decode: bool): + # Only rank 0 should print progress bar during capture + if is_global_first_rank(): + compilation_cases = tqdm( + compilation_cases, + disable=not self.load_config.use_tqdm_on_load, + desc="Capturing CUDA graphs ({})".format( + "pure decode" if is_pure_decode else "mix prefill-decode")) + # We skip EPLB here since we don't want to record dummy metrics + for num_tokens in compilation_cases: + for _ in range(self.compilation_config.cudagraph_num_of_warmups): + # Use CUDAGraphRuntimeStyle.NONE (default) for warmup. + # But be careful, warm up with `NONE`is orthogonal to + # if we want to warm up attention or not. This is + # different from the case where `FULL` implies capture + # attention while `PIECEWISE` implies no attention. + force_attention = cudagraph_runtime_style == CUDAGraphMode.FULL + self._dummy_run(num_tokens, + is_pure_decode=is_pure_decode, + cudagraph_runtime_style=CUDAGraphMode.NONE, + force_attention=force_attention, + skip_eplb=True) + self._dummy_run(num_tokens, + cudagraph_runtime_style=cudagraph_runtime_style, + is_pure_decode=False, + skip_eplb=True) + def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None: """ Initialize the attention backends and attention metadata builders. @@ -2396,6 +2385,9 @@ def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None: assert len(self.attn_backends) == 0 and len( self.attn_metadata_builders ) == 0, "Attention backends are already initialized" + + # Record the attention cudagraph support of the first spec. 
+ attn_cg: AttentionCGSupport = None # type: ignore for i, kv_cache_group_spec in enumerate( kv_cache_config.kv_cache_groups): kv_cache_spec = kv_cache_group_spec.kv_cache_spec @@ -2432,9 +2424,18 @@ def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None: block_table_i, ) - if self.full_cuda_graph: - attn_cg: AttentionCGSupport = \ - attn_metadata_builder_i.attn_cudagraph_support + if self.cudagraph_mode == CUDAGraphMode.FULL: + if attn_cg is None: + attn_cg = attn_metadata_builder_i.attn_cudagraph_support + else: + if attn_cg != attn_metadata_builder_i.\ + attn_cudagraph_support: + raise ValueError( + "All attention backends must have the same " + "AttentionCGSupport type when using full " + "CUDAGraph. Set CompilationConfig.cudagraph_mode" + " to `PIECEWISE` instead.") + if attn_cg == AttentionCGSupport.NEVER: raise ValueError( f"Full CUDAGraph not supported for " @@ -2455,25 +2456,30 @@ def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None: # check if the attention backends compatible with # CompilationConfig.separate_attention_routine - is_updated = False - expected = False if attn_cg == AttentionCGSupport.ALWAYS_UNIFIED and \ self.compilation_config.separate_attention_routine: expected = False - is_updated = True + logger.warning_once( + f"Full CUDAGraph for {attn_backend_i.__name__} " + f"supports unified attention routine for mixed " + f"batches or pure decode batches, which expect " + f"CompilationConfig.separate_attention_rountine" + f" as: {expected}. Now set it to: {expected}.") + self.compilation_config.separate_attention_routine = \ + expected + if attn_cg == AttentionCGSupport.PURE_DECODE_ONLY and \ not self.compilation_config.separate_attention_routine: expected = True - is_updated = True - - if is_updated: logger.warning_once( - f"Full CUDAGraph for {attn_backend_i.__name__}" - f"expect CompilationConfig.separate_attention" - f"_rountine as: {expected}. Now set it to: " - f"{expected}.") + f"Full CUDAGraph for {attn_backend_i.__name__} " + f"requires separate attention routines for mixed " + f"batches or pure decode batches, which expect " + f"CompilationConfig.separate_attention_rountine" + f" as: {expected}. Now set it to: {expected}.") self.compilation_config.separate_attention_routine = \ expected + # when AttentionCGSupport.ALWAYS_SEPARATE, we don't change # the separate_attention_routine flag, but should inform # the user that this flag can be turned on to obtain @@ -2491,6 +2497,11 @@ def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None: # we skip capturing mix prefill-decode (general) batches. if attn_cg == AttentionCGSupport.PURE_DECODE_ONLY and \ self.no_compilation: + logger.warning_once( + f"Skipping capturing mixed prefill-decode batches, " + f"since full cudagraph for {attn_backend_i.__name__}" + f"only supports pure decode batches while piecewise " + f"cudagraph is disabled as no vllm compilation.") self.capture_mixed_batches = False self.attn_backends.append(attn_backend_i) @@ -2499,7 +2510,14 @@ def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None: # Trigger cudagraph initialization here (after # initializing attn backends). # TODO: move this to better place. - self.cudagraph_dispatcher.maybe_initialize_cudagraph() + + # if we need capture full cudagraph for mixed prefill-decode batches. 
+ create_mixed_batch_full_cg = attn_cg in [ + AttentionCGSupport.ALWAYS_UNIFIED, + AttentionCGSupport.ALWAYS_SEPARATE] and \ + self.capture_mixed_batches + self.cudagraph_dispatcher.maybe_initialize_cudagraph( + create_mixed_batch_full_cg) def may_reinitialize_input_batch(self, kv_cache_config: KVCacheConfig) -> None: From ef3d9d96b3d4c9bd4cbd0e4958c9abbf6147cfe5 Mon Sep 17 00:00:00 2001 From: fhl2000 <63384265+fhl2000@users.noreply.github.com> Date: Sun, 13 Jul 2025 10:44:17 +0000 Subject: [PATCH 27/33] resolve yapf conflicts with isort Signed-off-by: fhl2000 <63384265+fhl2000@users.noreply.github.com> --- vllm/v1/attention/backends/flashinfer.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index d242d694a7a..754bda39a95 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -19,6 +19,8 @@ from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.v1.attention.backends.flash_attn import use_cascade_attention +# yapf conflicts with isort for this block +# yapf: disable from vllm.v1.attention.backends.utils import (AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata, @@ -26,6 +28,7 @@ get_kv_cache_layout, get_per_layer_parameters, infer_global_hyperparameters) +# yapf: enable from vllm.v1.kv_cache_interface import AttentionSpec from vllm.v1.worker.block_table import BlockTable From 658565ee1360156107caabfd33e91fc7be51deee Mon Sep 17 00:00:00 2001 From: fhl2000 <63384265+fhl2000@users.noreply.github.com> Date: Sun, 13 Jul 2025 15:10:42 +0000 Subject: [PATCH 28/33] fixes Signed-off-by: fhl2000 <63384265+fhl2000@users.noreply.github.com> --- vllm/platforms/interface.py | 7 +++++-- vllm/v1/worker/gpu_model_runner.py | 14 ++++++++------ 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index e0a071f97d1..572fa134772 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -30,8 +30,6 @@ SamplingParams = None FlexibleArgumentParser = None -_global_graph_pool = None - logger = init_logger(__name__) @@ -138,6 +136,8 @@ class Platform: additional_env_vars: list[str] = [] + _global_graph_pool: Optional[Any] = None + @property def supported_dtypes(self) -> list[torch.dtype]: """Returns the supported dtypes for the current platform.""" @@ -514,6 +514,9 @@ def validate_request( """Raises if this request is unsupported on this platform""" def __getattr__(self, key: str): + if hasattr(self, key): + return getattr(self, key) + device = getattr(torch, self.device_type, None) if device is not None and hasattr(device, key): return getattr(device, key) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index a66655516a4..7980078769a 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2367,12 +2367,14 @@ def _capture_cudagraphs(self, compilation_cases: list[int], # if we want to warm up attention or not. This is # different from the case where `FULL` implies capture # attention while `PIECEWISE` implies no attention. 
- force_attention = cudagraph_runtime_style == CUDAGraphMode.FULL - self._dummy_run(num_tokens, - is_pure_decode=is_pure_decode, - cudagraph_runtime_style=CUDAGraphMode.NONE, - force_attention=force_attention, - skip_eplb=True) + force_attention = ( + cudagraph_runtime_style == CUDAGraphRuntimeStyle.FULL) + self._dummy_run( + num_tokens, + is_pure_decode=is_pure_decode, + cudagraph_runtime_style=CUDAGraphRuntimeStyle.NONE, + force_attention=force_attention, + skip_eplb=True) self._dummy_run(num_tokens, cudagraph_runtime_style=cudagraph_runtime_style, is_pure_decode=False, From 15e2b4a6ebfc2dd0a7509e04764fff9a7b240746 Mon Sep 17 00:00:00 2001 From: fhl2000 <63384265+fhl2000@users.noreply.github.com> Date: Sun, 13 Jul 2025 16:33:58 +0000 Subject: [PATCH 29/33] fix global graph pool issue Signed-off-by: fhl2000 <63384265+fhl2000@users.noreply.github.com> --- vllm/platforms/interface.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index 572fa134772..bdde1356d89 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -514,9 +514,6 @@ def validate_request( """Raises if this request is unsupported on this platform""" def __getattr__(self, key: str): - if hasattr(self, key): - return getattr(self, key) - device = getattr(torch, self.device_type, None) if device is not None and hasattr(device, key): return getattr(device, key) @@ -529,9 +526,10 @@ def get_global_graph_pool(self) -> Any: """ Return the global graph pool for the this platform. """ - if not hasattr(self, '_global_graph_pool'): - self._global_graph_pool = self.graph_pool_handle() - return self._global_graph_pool + cls = type(self) + if cls._global_graph_pool is None: + cls._global_graph_pool = self.graph_pool_handle() + return cls._global_graph_pool @classmethod def get_cu_count(cls, device_id: int = 0) -> int: From 4253dbf8d9ccc98ced58f112cac957c70d02072c Mon Sep 17 00:00:00 2001 From: fhl2000 <63384265+fhl2000@users.noreply.github.com> Date: Sun, 13 Jul 2025 16:49:37 +0000 Subject: [PATCH 30/33] fix refactors Signed-off-by: fhl2000 <63384265+fhl2000@users.noreply.github.com> --- vllm/v1/worker/gpu_model_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 7980078769a..6b5df54967e 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2377,7 +2377,7 @@ def _capture_cudagraphs(self, compilation_cases: list[int], skip_eplb=True) self._dummy_run(num_tokens, cudagraph_runtime_style=cudagraph_runtime_style, - is_pure_decode=False, + is_pure_decode=is_pure_decode, skip_eplb=True) def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None: From 1b5496203bd3cb0c5ab6a16ce55049f1b96b9ace Mon Sep 17 00:00:00 2001 From: fhl2000 <63384265+fhl2000@users.noreply.github.com> Date: Mon, 14 Jul 2025 07:54:25 +0000 Subject: [PATCH 31/33] more refactors Signed-off-by: fhl2000 <63384265+fhl2000@users.noreply.github.com> --- vllm/config.py | 10 ++- vllm/platforms/interface.py | 2 +- vllm/v1/attention/backends/utils.py | 4 +- vllm/v1/cudagraph_dispatcher.py | 98 ++++++++++++++--------------- vllm/v1/worker/gpu_model_runner.py | 66 ++++++++++--------- 5 files changed, 90 insertions(+), 90 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 5795d3e1d2d..172821f8a50 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -4451,10 +4451,8 @@ def set_splitting_ops_for_v1(self): 
"vllm.unified_attention_with_output", ] elif len(self.splitting_ops) == 0: - assert self.cudagraph_mode == CUDAGraphMode.FULL, ( - "Seting splitting_ops as empty list requires " - "cudagraph_mode be CUDAGraphMode.FULL") - + assert self.cudagraph_mode != CUDAGraphMode.PIECEWISE, ( + "Cannot use piecewise CUDAGraph without splitting ops.") self.splitting_ops = [] @@ -4724,8 +4722,8 @@ def __post_init__(self): self.compilation_config.custom_ops.append("+rms_norm") if envs.VLLM_USE_V1 and self.model_config is not None and \ not self.model_config.enforce_eager: - # By default, V1 uses piecewise CUDA graphs. If full_cuda_graph - # is set to True, full CUDA graphs will be used. + # By default, V1 uses piecewise CUDA graphs. If cudagraph_mode + # is set to `FULL`, full CUDA graphs will be used. self.compilation_config.cudagraph_num_of_warmups = 1 if self.compilation_config.level is None: self.compilation_config.level = CompilationLevel.PIECEWISE diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index 874a1f10159..881f069f923 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -527,7 +527,7 @@ def get_global_graph_pool(self) -> Any: """ Return the global graph pool for the this platform. """ - cls = type(self) + cls = self.__class__ if cls._global_graph_pool is None: cls._global_graph_pool = self.graph_pool_handle() return cls._global_graph_pool diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index 531883d42af..3b589268bec 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -65,8 +65,8 @@ class AttentionCGSupport(enum.Enum): ALWAYS_UNIFIED = 2 """Cudagraph always supported with unified routine""" ALWAYS_SEPARATE = 3 - """ Cudagraph supported for both mixed prefill-decode - or pure decode attention routines.""" + """Cudagraph always supported, with better performance when separate + routines are used for mixed prefill-decode and pure decode batches.""" class AttentionMetadataBuilder(abc.ABC, Generic[M]): diff --git a/vllm/v1/cudagraph_dispatcher.py b/vllm/v1/cudagraph_dispatcher.py index a61f35bd5cd..927b01f26ab 100644 --- a/vllm/v1/cudagraph_dispatcher.py +++ b/vllm/v1/cudagraph_dispatcher.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Any, Callable +from typing import Any, Callable, NamedTuple, Optional from vllm.compilation.cuda_graph import CUDAGraphOptions, CUDAGraphWrapper from vllm.config import (CompilationLevel, CUDAGraphMode, @@ -9,8 +9,15 @@ logger = init_logger(__name__) -# constant for pure decode -DECODE_BOOLEN = True + +class DispatchKey(NamedTuple): + """ + Key for dispatching cudagraphs. + """ + cudagraph_runtime_style: CUDAGraphRuntimeStyle + # Be aware that is_pure_decode should be default None + # for both piecewise cudagraphs and no cudagraphs. + is_pure_decode: Optional[bool] = None class CudagraphDispatcher: @@ -22,25 +29,28 @@ def __init__(self, vllm_config: VllmConfig): self.vllm_config = vllm_config self.compilation_config = vllm_config.compilation_config self.cudagraph_mode = self.compilation_config.cudagraph_mode - self.no_compilation = self.compilation_config.level != \ - CompilationLevel.PIECEWISE or \ - vllm_config.model_config.enforce_eager self.model: Callable = None # type: ignore - # we lazy initialize self.model after model loading of model + # we lazy initialize self.model once the model loading of # runner have been done. 
# Dict to store cudagraph candidates for runtime dispatching. - self.cudagraph_candidates: dict[tuple, Any] = {} + self.cudagraph_candidates: dict[DispatchKey, Any] = {} + + # Verify if correctly piecewise compilation for attention. + piecewise_compilation = not vllm_config.model_config.enforce_eager\ + and self.compilation_config.level == CompilationLevel.PIECEWISE + self.piecewise_attn_compilation = piecewise_compilation and\ + all(op in self.compilation_config.splitting_ops for op in [ + "vllm.unified_attention", "vllm.unified_attention_with_output"]) def after_load_model(self, model: Callable): # add original model to cudagraph_candidates for profile run. assert model is not None, "model should not be None" self.model = model - self.cudagraph_candidates.update({ - (CUDAGraphRuntimeStyle.NONE, ): - self.model - }) + self.cudagraph_candidates.update( + {DispatchKey(CUDAGraphRuntimeStyle.NONE): self.model}) + logger.debug("Cudagraph candidates for NONE style initialized") def maybe_initialize_cudagraph(self, create_mixed_batch_full_cg: bool): assert self.model is not None, ( @@ -49,17 +59,15 @@ def maybe_initialize_cudagraph(self, create_mixed_batch_full_cg: bool): if self.compilation_config.level == CompilationLevel.PIECEWISE\ and len(self.compilation_config.splitting_ops)>0: - self.cudagraph_candidates.update({ - (CUDAGraphRuntimeStyle.PIECEWISE, ): - self.model - }) + self.cudagraph_candidates.update( + {DispatchKey(CUDAGraphRuntimeStyle.PIECEWISE): self.model}) logger.debug("Piecewise cudagraph initialized") if self.compilation_config.cudagraph_mode == CUDAGraphMode.FULL: # create full cudagraph for mix prefill-decode/general batches if create_mixed_batch_full_cg: self.cudagraph_candidates.update({ - (CUDAGraphRuntimeStyle.FULL, not DECODE_BOOLEN): + DispatchKey(CUDAGraphRuntimeStyle.FULL, False): CUDAGraphWrapper(self.model, self.vllm_config, runtime_style=CUDAGraphRuntimeStyle.FULL, @@ -71,7 +79,7 @@ def maybe_initialize_cudagraph(self, create_mixed_batch_full_cg: bool): # attention routine. if self.compilation_config.separate_attention_routine: self.cudagraph_candidates.update({ - (CUDAGraphRuntimeStyle.FULL, DECODE_BOOLEN): + DispatchKey(CUDAGraphRuntimeStyle.FULL, True): CUDAGraphWrapper(self.model, self.vllm_config, runtime_style=CUDAGraphRuntimeStyle.FULL, @@ -94,27 +102,24 @@ def get_cudagraph_runtime_style( return CUDAGraphRuntimeStyle.PIECEWISE # Otherwise, for modes that enable full cudagraph. + assert self.cudagraph_mode == CUDAGraphMode.FULL + # If attention backend supports full cudagraphs for current batch, + # run with full cudagraphs. + if attn_cuda_graphs: + return CUDAGraphRuntimeStyle.FULL + + # Fall back to piecewise cudagraphs if possible + if self.piecewise_attn_compilation: + return CUDAGraphRuntimeStyle.PIECEWISE - # Some attention backends only support CUDA Graphs in pure decode. - # If attention doesn't support CUDA Graphs for this batch, we skip them, - # and turn back to the piecewise CUDA graphs. - cudagraph_runtime_style = CUDAGraphRuntimeStyle.FULL if\ - attn_cuda_graphs else CUDAGraphRuntimeStyle.PIECEWISE - - # PIECEWISE would fall back to NONE if no compilation - if cudagraph_runtime_style == CUDAGraphRuntimeStyle.PIECEWISE and \ - self.no_compilation: - cudagraph_runtime_style = CUDAGraphRuntimeStyle.NONE - - #TODO: can we optimize above logic? 
- return cudagraph_runtime_style + # Otherwise, fall back to running entirely without cudagraphs + return CUDAGraphRuntimeStyle.NONE def dispatch(self, cudagraph_runtime_style: CUDAGraphRuntimeStyle, is_pure_decode: bool) -> Any: assert self.model is not None, ("No model have been assigned" "to cudagraph dispatcher") - # if no cudagraph candidates, - # just skip cudagraph dispatching. + # if no cudagraph candidates, just skip dispatching. if not self.cudagraph_candidates: logger.warning_once("cudagraphs are not initialized." " No cudagraph will be used.") @@ -124,23 +129,16 @@ def dispatch(self, cudagraph_runtime_style: CUDAGraphRuntimeStyle, if cudagraph_runtime_style in [ CUDAGraphRuntimeStyle.NONE, CUDAGraphRuntimeStyle.PIECEWISE ]: - selected_model = self.cudagraph_candidates.get( - (cudagraph_runtime_style, ), None) - assert selected_model is not None, ( - "cudagraph_candidates is not" - " correctly initialized for key: " - f"({cudagraph_runtime_style}, ).") + dispatchkey = DispatchKey(cudagraph_runtime_style) + selected_model = self.cudagraph_candidates.get(dispatchkey, None) else: - # for full cudagraph, select between general batches + # for full cudagraph, select between mixed batches # or pure decode batches - decode_case = (DECODE_BOOLEN,) if self.compilation_config.\ - separate_attention_routine and is_pure_decode \ - else (not DECODE_BOOLEN,) - tuple_key = (cudagraph_runtime_style, ) + decode_case - selected_model = self.cudagraph_candidates.get(tuple_key, None) - assert selected_model is not None, ( - "cudagraph_candidates is not" - " correctly initialized for key: " - f"({cudagraph_runtime_style}, " - f"{is_pure_decode}).") + decode_case = self.compilation_config.separate_attention_routine\ + and is_pure_decode + dispatchkey = DispatchKey(cudagraph_runtime_style, decode_case) + selected_model = self.cudagraph_candidates.get(dispatchkey, None) + assert selected_model is not None, ( + f"cudagraph_candidates is not correctly initialized for key: " + f"{dispatchkey}") return selected_model diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index d87a4b7859f..823ee5dc531 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -320,7 +320,7 @@ def __init__( # no support (e.g., no piecewise compilation) or want only capturing # full cudagraph for pure decode batches. self.capture_mixed_batches = self.cudagraph_mode != CUDAGraphMode.NONE - self.no_compilation = self.compilation_config.level != \ + self.no_piecewise_compilation = self.compilation_config.level != \ CompilationLevel.PIECEWISE or self.model_config.enforce_eager # Cudagraph dispatcher for runtime cudagraph dispatching. @@ -2312,12 +2312,10 @@ def capture_model(self) -> None: if self.capture_mixed_batches: # select between full cudagraph and piecewise cudagraph # for mixed prefill-decode batches. 
- attn_cuda_graphs = False if self.cudagraph_mode == \ - CUDAGraphMode.PIECEWISE else ( - self.attn_metadata_builders[0].attn_cudagraph_support in [ - AttentionCGSupport.ALWAYS_UNIFIED, - AttentionCGSupport.ALWAYS_SEPARATE, - ]) + attn_cuda_graphs = self.cudagraph_mode == CUDAGraphMode.FULL \ + and self.attn_metadata_builders[0].attn_cudagraph_support \ + in [AttentionCGSupport.ALWAYS_UNIFIED, + AttentionCGSupport.ALWAYS_SEPARATE] cudagraph_runtime_style = CUDAGraphRuntimeStyle.FULL if \ attn_cuda_graphs else CUDAGraphRuntimeStyle.PIECEWISE @@ -2381,9 +2379,9 @@ def _capture_cudagraphs(self, compilation_cases: list[int], cudagraph_runtime_style == CUDAGraphRuntimeStyle.FULL) self._dummy_run( num_tokens, - is_pure_decode=is_pure_decode, cudagraph_runtime_style=CUDAGraphRuntimeStyle.NONE, force_attention=force_attention, + is_pure_decode=is_pure_decode, skip_eplb=True) self._dummy_run(num_tokens, cudagraph_runtime_style=cudagraph_runtime_style, @@ -2470,27 +2468,24 @@ def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None: # CompilationConfig.separate_attention_routine if attn_cg == AttentionCGSupport.ALWAYS_UNIFIED and \ self.compilation_config.separate_attention_routine: - expected = False logger.warning_once( - f"Full CUDAGraph for {attn_backend_i.__name__} " - f"supports unified attention routine for mixed " - f"batches or pure decode batches, which expect " - f"CompilationConfig.separate_attention_rountine" - f" as: {expected}. Now set it to: {expected}.") + f"Full CUDAGraph support for {attn_backend_i.__name__}" + f" is {AttentionCGSupport.ALWAYS_UNIFIED}, which expect" + f"CompilationConfig.separate_attention_rountine as " + f"False. Set it to False now.") self.compilation_config.separate_attention_routine = \ - expected + False if attn_cg == AttentionCGSupport.PURE_DECODE_ONLY and \ not self.compilation_config.separate_attention_routine: - expected = True + logger.warning_once( - f"Full CUDAGraph for {attn_backend_i.__name__} " - f"requires separate attention routines for mixed " - f"batches or pure decode batches, which expect " - f"CompilationConfig.separate_attention_rountine" - f" as: {expected}. Now set it to: {expected}.") + f"Full CUDAGraph support for {attn_backend_i.__name__}" + f" is {AttentionCGSupport.PURE_DECODE_ONLY}, which " + f"expect CompilationConfig.separate_attention_rountine" + f"as True. Set it to True now.") self.compilation_config.separate_attention_routine = \ - expected + True # when AttentionCGSupport.ALWAYS_SEPARATE, we don't change # the separate_attention_routine flag, but should inform @@ -2505,16 +2500,25 @@ def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None: f"You can turn on CompilationConfig.separate_" f"attention_routine to obtain better performance.") - # for attn_cg is pure decode only, and no compilation, + # for attn_cg is pure decode only, and no piecewise compilation, # we skip capturing mix prefill-decode (general) batches. 
- if attn_cg == AttentionCGSupport.PURE_DECODE_ONLY and \ - self.no_compilation: - logger.warning_once( - f"Skipping capturing mixed prefill-decode batches, " - f"since full cudagraph for {attn_backend_i.__name__}" - f"only supports pure decode batches while piecewise " - f"cudagraph is disabled as no vllm compilation.") - self.capture_mixed_batches = False + if attn_cg == AttentionCGSupport.PURE_DECODE_ONLY: + if self.no_piecewise_compilation: + logger.warning_once( + f"Skipping capturing mixed prefill-decode batches, " + f"since backend {attn_backend_i.__name__} only " + f"supports full cudagraph for pure decode only and " + f"vllm piecewise compilation is no enabled.") + self.capture_mixed_batches = False + else: + assert all(op in self.compilation_config.splitting_ops + for op in ["vllm.unified_attention", + "vllm.unified_attention_with_output"]),\ + "Invalid splitting_ops for piecewise compilation" + "with cudagraph_mode `FULL` for backend " + f"{attn_backend_i.__name__}, which support " + "cudagraph on pure decode only. Please include " + "attention ops in compilation_config.splitting_ops" self.attn_backends.append(attn_backend_i) self.attn_metadata_builders.append(attn_metadata_builder_i) From d6269bd2694b5636ee6e77b0dc622f25e1d23565 Mon Sep 17 00:00:00 2001 From: fhl <2410591650@qq.com> Date: Thu, 17 Jul 2025 23:59:55 +0800 Subject: [PATCH 32/33] refactors for and more Signed-off-by: fhl <2410591650@qq.com> --- vllm/compilation/backends.py | 2 +- ...e_backend.py => cuda_piecewise_backend.py} | 0 vllm/config.py | 68 +++++++- vllm/v1/cudagraph_dispatcher.py | 47 +++--- vllm/v1/worker/gpu_model_runner.py | 147 ++++++++++-------- 5 files changed, 173 insertions(+), 91 deletions(-) rename vllm/compilation/{piecewise_backend.py => cuda_piecewise_backend.py} (100%) diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index a2f1175ef49..e48d38fe3dc 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -339,7 +339,7 @@ def call_module(self, target: torch.fx.node.Target, runtime_shape=None) # Lazy import here to avoid circular import from .cuda_graph import CUDAGraphOptions - from .piecewise_backend import PiecewiseBackend + from .cuda_piecewise_backend import PiecewiseBackend piecewise_backend = PiecewiseBackend( submod, self.vllm_config, index, diff --git a/vllm/compilation/piecewise_backend.py b/vllm/compilation/cuda_piecewise_backend.py similarity index 100% rename from vllm/compilation/piecewise_backend.py rename to vllm/compilation/cuda_piecewise_backend.py diff --git a/vllm/config.py b/vllm/config.py index 6d4fc37e06f..bc841dee913 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -4034,13 +4034,17 @@ class CompilationConfig: - [`custom_ops`][vllm.config.CompilationConfig.custom_ops] - [`splitting_ops`][vllm.config.CompilationConfig.splitting_ops] - CudaGraph capture: + - [`use_cudagraph`][vllm.config.CompilationConfig.use_cudagraph] - [`cudagraph_mode`][vllm.config.CompilationConfig.cudagraph_mode] - [`cudagraph_capture_sizes`] [vllm.config.CompilationConfig.cudagraph_capture_sizes] - [`cudagraph_num_of_warmups`] [vllm.config.CompilationConfig.cudagraph_num_of_warmups] + - [`cudagraph_separate_routine`] + [vllm.config.CompilationConfig.cudagraph_separate_routine] - [`cudagraph_copy_inputs`] [vllm.config.CompilationConfig.cudagraph_copy_inputs] + - [`full_cuda_graph`][vllm.config.CompilationConfig.full_cuda_graph] - Inductor compilation: - [`use_inductor`][vllm.config.CompilationConfig.use_inductor] - 
[`compile_sizes`][vllm.config.CompilationConfig.compile_sizes] @@ -4101,6 +4105,8 @@ class CompilationConfig: splitting_ops: Optional[list[str]] = None """A list of ops to split the full graph into subgraphs, used in piecewise compilation.""" + is_attention_splitting: bool = False + """A flag to indicate if the splitting_ops contains all attention ops.""" # Inductor capture use_inductor: bool = True @@ -4155,6 +4161,19 @@ class CompilationConfig: compilation (level=PIECEWISE and non-empty splitting_ops), full cudagraphs are supported with and without compilation. """ + use_cudagraph: Optional[bool] = None + """Whether to use cudagraph inside compilation. + - False: cudagraph inside compilation is not used. + - True: cudagraph inside compilation is used. It requires + that all input buffers have fixed addresses, and all + splitting ops write their outputs to input buffers. + In the vLLM V1 Engine, this flag only applies for + CompilationLevel.PIECEWISE (aka -O3). + Note that this is orthogonal to the cudagraph capture logic + outside of compilation. + Warning: This flag is deprecated and will be removed in future releases. + Please use cudagraph_mode instead. + """ cudagraph_num_of_warmups: int = 0 """Number of warmup runs for cudagraph. It means the first several runs will be treated as warmup runs. @@ -4169,8 +4188,18 @@ class CompilationConfig: cudagraph. If the caller can guarantee that the same input buffers are always used, it can set this to False. Otherwise, it should set this to True, and the compiler will copy the input to an - internally managed buffer. Default is False.""" - separate_attention_routine: bool = False + internally managed buffer. Default is False. + Note that this flag is only effective when cudagraph_mode is PIECEWISE. + """ + full_cuda_graph: Optional[bool] = None + """whether to use a full cuda graph for the entire forward pass rather than + splitting certain operations such as attention into subgraphs. Thus this + flag cannot be used together with splitting_ops. This may provide + performance benefits for smaller models. + Warning: This flag is deprecated and will be removed in future releases. + Please use cudagraph_mode instead. + """ + cudagraph_separate_routine: bool = False """ Enable distinct attention routines for mixed and pure-decode batches during full cuda graph capturing. This is because some attention backends like @@ -4385,9 +4414,9 @@ def init_with_cudagraph_sizes(self, def set_splitting_ops_for_v1(self): # NOTE: this function needs to be called - if self.separate_attention_routine: + if self.cudagraph_separate_routine: assert self.cudagraph_mode == CUDAGraphMode.FULL, ( - "separate_attention_routine requires " + "cudagraph_separate_routine requires " "cudagraph_mode be CUDAGraphMode.FULL") if self.splitting_ops is None: @@ -4405,6 +4434,8 @@ def set_splitting_ops_for_v1(self): assert self.cudagraph_mode != CUDAGraphMode.PIECEWISE, ( "Cannot use piecewise CUDAGraph without splitting ops.") self.splitting_ops = [] + self.is_attention_splitting = all(op in self.splitting_ops for op in [ + "vllm.unified_attention", "vllm.unified_attention_with_output"]) @config @@ -4680,6 +4711,35 @@ def __post_init__(self): self.compilation_config.level = CompilationLevel.PIECEWISE self.compilation_config.set_splitting_ops_for_v1() + if self.compilation_config.use_cudagraph is not None: + logger.warning( + "`use_cudagraph` is deprecated and will be removed in the " + "future release. 
Switch to use `cudagraph_mode` instead.") + if self.compilation_config.use_cudagraph: + if self.compilation_config.cudagraph_mode == CUDAGraphMode.NONE: + self.compilation_config.cudagraph_mode =\ + CUDAGraphMode.PIECEWISE + # otherwise, keep the cudagraph_mode as is + else: + self.compilation_config.cudagraph_mode = CUDAGraphMode.NONE + + if self.compilation_config.full_cuda_graph is not None: + logger.warning( + "`full_cuda_graph` is deprecated and will be removed in the " + "future release. Switch to use `cudagraph_mode` instead.") + if self.compilation_config.use_cudagraph is not None and \ + self.compilation_config.full_cuda_graph: + assert self.compilation_config.use_cudagraph, ( + "`use_cudagraph` must be True when `full_cuda_graph` " + "is True.") + self.compilation_config.cudagraph_mode = CUDAGraphMode.FULL + self.compilation_config.use_cudagraph = True + elif self.compilation_config.full_cuda_graph: + self.compilation_config.cudagraph_mode = CUDAGraphMode.FULL + self.compilation_config.use_cudagraph = True + # other cases, keep the cudagraph_mode as is + + # For V0 or other cases, default to level 0 with no compilation if self.compilation_config.level is None: self.compilation_config.level = CompilationLevel.NO_COMPILATION diff --git a/vllm/v1/cudagraph_dispatcher.py b/vllm/v1/cudagraph_dispatcher.py index 927b01f26ab..fa4a6484a62 100644 --- a/vllm/v1/cudagraph_dispatcher.py +++ b/vllm/v1/cudagraph_dispatcher.py @@ -10,14 +10,14 @@ logger = init_logger(__name__) -class DispatchKey(NamedTuple): +class CudagraphKey(NamedTuple): """ Key for dispatching cudagraphs. """ cudagraph_runtime_style: CUDAGraphRuntimeStyle - # Be aware that is_pure_decode should be default None + # Be aware that uniform_batch should be default None # for both piecewise cudagraphs and no cudagraphs. - is_pure_decode: Optional[bool] = None + uniform_batch: Optional[bool] = None class CudagraphDispatcher: @@ -35,21 +35,20 @@ def __init__(self, vllm_config: VllmConfig): # runner have been done. # Dict to store cudagraph candidates for runtime dispatching. - self.cudagraph_candidates: dict[DispatchKey, Any] = {} + self.cudagraph_candidates: dict[CudagraphKey, Any] = {} # Verify if correctly piecewise compilation for attention. piecewise_compilation = not vllm_config.model_config.enforce_eager\ and self.compilation_config.level == CompilationLevel.PIECEWISE self.piecewise_attn_compilation = piecewise_compilation and\ - all(op in self.compilation_config.splitting_ops for op in [ - "vllm.unified_attention", "vllm.unified_attention_with_output"]) + self.compilation_config.is_attention_splitting def after_load_model(self, model: Callable): # add original model to cudagraph_candidates for profile run. 
assert model is not None, "model should not be None" self.model = model self.cudagraph_candidates.update( - {DispatchKey(CUDAGraphRuntimeStyle.NONE): self.model}) + {CudagraphKey(CUDAGraphRuntimeStyle.NONE): self.model}) logger.debug("Cudagraph candidates for NONE style initialized") def maybe_initialize_cudagraph(self, create_mixed_batch_full_cg: bool): @@ -60,14 +59,14 @@ def maybe_initialize_cudagraph(self, create_mixed_batch_full_cg: bool): if self.compilation_config.level == CompilationLevel.PIECEWISE\ and len(self.compilation_config.splitting_ops)>0: self.cudagraph_candidates.update( - {DispatchKey(CUDAGraphRuntimeStyle.PIECEWISE): self.model}) + {CudagraphKey(CUDAGraphRuntimeStyle.PIECEWISE): self.model}) logger.debug("Piecewise cudagraph initialized") if self.compilation_config.cudagraph_mode == CUDAGraphMode.FULL: # create full cudagraph for mix prefill-decode/general batches if create_mixed_batch_full_cg: self.cudagraph_candidates.update({ - DispatchKey(CUDAGraphRuntimeStyle.FULL, False): + CudagraphKey(CUDAGraphRuntimeStyle.FULL, False): CUDAGraphWrapper(self.model, self.vllm_config, runtime_style=CUDAGraphRuntimeStyle.FULL, @@ -75,19 +74,19 @@ def maybe_initialize_cudagraph(self, create_mixed_batch_full_cg: bool): usage_str="full/mixed")) }) logger.debug("Full cudagraph for mixed batches initialized") - # always create full cudagraph for pure decode batches if speparate - # attention routine. - if self.compilation_config.separate_attention_routine: + # always create full cudagraph for uniform batches if cudagraph + # separate routine is enabled. + if self.compilation_config.cudagraph_separate_routine: self.cudagraph_candidates.update({ - DispatchKey(CUDAGraphRuntimeStyle.FULL, True): + CudagraphKey(CUDAGraphRuntimeStyle.FULL, True): CUDAGraphWrapper(self.model, self.vllm_config, runtime_style=CUDAGraphRuntimeStyle.FULL, cudagraph_options=CUDAGraphOptions( - usage_str="full/pure-decode")) + usage_str="full/uniform")) }) logger.debug( - "Full cudagraph for pure decode batches initialized") + "Full cudagraph for uniform batches initialized") def get_cudagraph_runtime_style( self, attn_cuda_graphs: bool) -> CUDAGraphRuntimeStyle: # noqa @@ -116,7 +115,7 @@ def get_cudagraph_runtime_style( return CUDAGraphRuntimeStyle.NONE def dispatch(self, cudagraph_runtime_style: CUDAGraphRuntimeStyle, - is_pure_decode: bool) -> Any: + uniform_batch: bool) -> Any: assert self.model is not None, ("No model have been assigned" "to cudagraph dispatcher") # if no cudagraph candidates, just skip dispatching. 
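
For reference, the dispatch-key pattern used by the dispatcher above can be summarized in a standalone sketch. The enum, key, and registration helper below are simplified stand-ins for illustration only; they are not the vLLM classes, and the real dispatcher wraps the model in CUDAGraphWrapper objects rather than plain callables.

# Minimal sketch of the (runtime_style, uniform_batch) dispatch pattern.
# All names here are illustrative stand-ins, not the vLLM implementations.
from enum import Enum
from typing import Any, Callable, NamedTuple, Optional


class RuntimeStyle(Enum):
    NONE = 0
    PIECEWISE = 1
    FULL = 2


class Key(NamedTuple):
    runtime_style: RuntimeStyle
    uniform_batch: Optional[bool] = None  # None for NONE/PIECEWISE styles


class Dispatcher:
    def __init__(self, separate_routine: bool):
        self.separate_routine = separate_routine
        self.candidates: dict[Key, Callable[..., Any]] = {}

    def register(self, style: RuntimeStyle, fn: Callable[..., Any],
                 uniform_batch: Optional[bool] = None) -> None:
        self.candidates[Key(style, uniform_batch)] = fn

    def dispatch(self, style: RuntimeStyle,
                 uniform_batch: bool) -> Callable[..., Any]:
        if style in (RuntimeStyle.NONE, RuntimeStyle.PIECEWISE):
            key = Key(style)
        else:
            # Only route to the uniform-batch graph when the separate
            # routine is enabled; otherwise fall back to the mixed graph.
            key = Key(style, uniform_batch and self.separate_routine)
        return self.candidates[key]


if __name__ == "__main__":
    d = Dispatcher(separate_routine=True)
    d.register(RuntimeStyle.NONE, lambda: "eager")
    d.register(RuntimeStyle.FULL, lambda: "full/mixed", uniform_batch=False)
    d.register(RuntimeStyle.FULL, lambda: "full/uniform", uniform_batch=True)
    print(d.dispatch(RuntimeStyle.FULL, uniform_batch=True)())  # full/uniform
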
@@ -129,16 +128,16 @@ def dispatch(self, cudagraph_runtime_style: CUDAGraphRuntimeStyle, if cudagraph_runtime_style in [ CUDAGraphRuntimeStyle.NONE, CUDAGraphRuntimeStyle.PIECEWISE ]: - dispatchkey = DispatchKey(cudagraph_runtime_style) - selected_model = self.cudagraph_candidates.get(dispatchkey, None) + key = CudagraphKey(cudagraph_runtime_style) + selected_model = self.cudagraph_candidates.get(key, None) else: # for full cudagraph, select between mixed batches - # or pure decode batches - decode_case = self.compilation_config.separate_attention_routine\ - and is_pure_decode - dispatchkey = DispatchKey(cudagraph_runtime_style, decode_case) - selected_model = self.cudagraph_candidates.get(dispatchkey, None) + # or uniform batches + uniform_batch = uniform_batch and\ + self.compilation_config.cudagraph_separate_routine + key = CudagraphKey(cudagraph_runtime_style, uniform_batch) + selected_model = self.cudagraph_candidates.get(key, None) assert selected_model is not None, ( f"cudagraph_candidates is not correctly initialized for key: " - f"{dispatchkey}") + f"{key}") return selected_model diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 6a6046aff7f..875d905ffe7 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -319,6 +319,9 @@ def __init__( self.capture_mixed_batches = self.cudagraph_mode != CUDAGraphMode.NONE self.no_piecewise_compilation = self.compilation_config.level != \ CompilationLevel.PIECEWISE or self.model_config.enforce_eager + + self.uniform_decode_query_len = 1 if not self.speculative_config else \ + 1 + self.speculative_config.num_speculative_tokens # Cudagraph dispatcher for runtime cudagraph dispatching. self.cudagraph_dispatcher = CudagraphDispatcher(self.vllm_config) @@ -769,7 +772,8 @@ def _prepare_inputs( return (attn_metadata, attention_cuda_graphs, logits_indices, spec_decode_metadata, num_scheduled_tokens, - spec_decode_common_attn_metadata) + spec_decode_common_attn_metadata, + max_num_scheduled_tokens) def _compute_cascade_attn_prefix_len( self, @@ -1291,7 +1295,7 @@ def execute_model( # Prepare the decoder inputs. (attn_metadata, attention_cuda_graphs, logits_indices, spec_decode_metadata, num_scheduled_tokens_np, - spec_decode_common_attn_metadata) = ( + spec_decode_common_attn_metadata, max_query_len) = ( self._prepare_inputs(scheduler_output)) num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens if (self.cudagraph_mode != CUDAGraphMode.NONE @@ -1359,10 +1363,11 @@ def execute_model( cudagraph_runtime_style = self.cudagraph_dispatcher.\ get_cudagraph_runtime_style(attention_cuda_graphs) # Note: When cudagraph_mode is FULL and - # compilation_config.separate_attention_routine is True, as in FA2, - # this flag helps to determine the correct routine for the full - # cudagraph. - is_pure_decode = num_scheduled_tokens == self.input_batch.num_reqs + # compilation_config.cudagraph_separate_routine is True, this + # flag helps to determine the correct cudagraph routine (optimized + # for attention ops). + uniform_batch = max_query_len == self.uniform_decode_query_len and \ + num_scheduled_tokens == self.input_batch.num_reqs*max_query_len # Run the model. # Use persistent buffers for CUDA graphs. 
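
The uniform-batch test used in execute_model above reduces to two integer comparisons against the uniform decode query length. A minimal sketch with illustrative stand-in names, assuming uniform_decode_query_len = 1 + num_speculative_tokens as in this patch:

# Sketch of the uniform-batch check that selects the decode-optimized
# full-cudagraph routine. Variable names are illustrative stand-ins.
def is_uniform_batch(max_query_len: int, num_scheduled_tokens: int,
                     num_reqs: int, num_spec_tokens: int = 0) -> bool:
    uniform_decode_query_len = 1 + num_spec_tokens
    return (max_query_len == uniform_decode_query_len
            and num_scheduled_tokens == num_reqs * max_query_len)


# Pure decode: 8 requests, 1 token each.
assert is_uniform_batch(max_query_len=1, num_scheduled_tokens=8, num_reqs=8)
# Speculative decode with 2 draft tokens: every request schedules 3 tokens.
assert is_uniform_batch(3, 24, 8, num_spec_tokens=2)
# Mixed prefill-decode: one request with a long query breaks uniformity.
assert not is_uniform_batch(17, 24, 8)
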
@@ -1375,7 +1380,7 @@ def execute_model( self.maybe_setup_kv_connector(scheduler_output) model = self.cudagraph_dispatcher.dispatch(cudagraph_runtime_style, - is_pure_decode) + uniform_batch) model_output = model( input_ids=input_ids, @@ -1938,7 +1943,7 @@ def _dummy_run( cudagraph_runtime_style: CUDAGraphRuntimeStyle = ( CUDAGraphRuntimeStyle.NONE), force_attention: bool = False, - is_pure_decode: bool = False, + uniform_batch: bool = False, skip_eplb: bool = False, is_profile: bool = False, ) -> tuple[torch.Tensor, torch.Tensor]: @@ -1946,29 +1951,48 @@ def _dummy_run( # Padding for DP num_pad, num_tokens_across_dp = self.get_dp_padding(num_tokens) num_tokens += num_pad - + + # If cudagraph_separate_routine is enabled when use full cudagraph, + # we need to manually activate the correct routine of attention backend + # for mixed prefill-decode batches and uniform decode batches + # separately during capturing. Uniform batch means that all + # requests have identical query length, except possibly a single, + # shorter dummy request in the batch (account for padding when + # cudagraph capturing). + # An example of uniform batch is common pure decode, where + # max_query_len == 1. Another case is speculative decode, + # where max_query_len == 1 + num_spec_decode_tokens. + + # When setting max_query_len = 1, we switch to and capture the optimized + # routine of FA2 for pure decode, i.e., Flashdecode + an optimization + # for GQA/MQA. + max_query_len = self.uniform_decode_query_len if uniform_batch else \ + num_tokens + # Set num_scheduled_tokens based on num_tokens and max_num_seqs # for dummy run with LoRA so that the num_reqs collectively # has num_tokens in total. assert num_tokens <= self.scheduler_config.max_num_batched_tokens max_num_reqs = self.scheduler_config.max_num_seqs - num_reqs = min(num_tokens, max_num_reqs) - min_tokens_per_req = num_tokens // num_reqs - num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs - num_scheduled_tokens_list[-1] += num_tokens % num_reqs + if uniform_batch: + num_reqs = (num_tokens+self.uniform_decode_query_len-1) // \ + self.uniform_decode_query_len + assert num_reqs <= max_num_reqs, \ + "Do not capture num_reqs > max_num_reqs for uniform batch" + num_scheduled_tokens_list = [max_query_len] * num_reqs + if num_tokens % max_query_len != 0: + num_scheduled_tokens_list[-1] = num_tokens % max_query_len + else: + num_reqs = min(num_tokens, max_num_reqs) + min_tokens_per_req = num_tokens // num_reqs + num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs + num_scheduled_tokens_list[-1] += num_tokens % num_reqs + assert sum(num_scheduled_tokens_list) == num_tokens assert len(num_scheduled_tokens_list) == num_reqs num_scheduled_tokens = np.array(num_scheduled_tokens_list, dtype=np.int32) - # If separate_attention_routine for attention backend is enabled when - # use full cudagraph, we need to manually activate the correct routine - # for mixed prefill-decode batches and pure decode batches separately - # during capturing. - # For example, below code switches to the optimized routine of FA2 - # for pure decoding, i.e., Flashdecode + an optimization for GQA/MQA. - max_query_len = 1 if is_pure_decode else num_tokens - attn_metadata: Optional[dict[str, Any]] = None # If force_attention is True, we always capture attention. 
Otherwise, @@ -2040,7 +2064,7 @@ def _dummy_run( num_tokens_across_dp=num_tokens_across_dp, cudagraph_runtime_style=cudagraph_runtime_style): model = self.cudagraph_dispatcher.dispatch( - cudagraph_runtime_style, is_pure_decode) + cudagraph_runtime_style, uniform_batch) outputs = model( input_ids=input_ids, positions=positions, @@ -2300,10 +2324,10 @@ def capture_model(self) -> None: attn_cuda_graphs else CUDAGraphRuntimeStyle.PIECEWISE # Skip capturing batch sizes of 1 in mix prefill-decode if - # separate_attention_routine is on. As bs=1 can treat as a - # pure decode. + # cudagraph_separate_routine is on. As bs=1 can treat as a + # uniform batch. start_idx = 0 - if self.compilation_config.separate_attention_routine \ + if self.compilation_config.cudagraph_separate_routine \ and len(self.cudagraph_batch_sizes) > 0 \ and self.cudagraph_batch_sizes[0] == 1: start_idx = 1 @@ -2312,22 +2336,24 @@ def capture_model(self) -> None: self._capture_cudagraphs( compilation_cases, cudagraph_runtime_style=cudagraph_runtime_style, - is_pure_decode=False) + uniform_batch=False) - if self.compilation_config.separate_attention_routine: - # Capture full cudagraph for pure decode. + if self.compilation_config.cudagraph_separate_routine: + # Capture full cudagraph for uniform batches (pure decode/ + # speculative decode). cudagraph_runtime_style = CUDAGraphRuntimeStyle.FULL - max_num_reqs = self.scheduler_config.max_num_seqs + max_num_tokens = self.scheduler_config.max_num_seqs * \ + self.uniform_decode_query_len decode_cudagraph_batch_sizes = [ - x for x in self.cudagraph_batch_sizes if x <= max_num_reqs + x for x in self.cudagraph_batch_sizes if x <= max_num_tokens ] compilation_cases_decode = list( reversed(decode_cudagraph_batch_sizes)) self._capture_cudagraphs( compilation_cases=compilation_cases_decode, cudagraph_runtime_style=cudagraph_runtime_style, - is_pure_decode=True) + uniform_batch=True) end_time = time.perf_counter() end_free_gpu_memory = torch.cuda.mem_get_info()[0] @@ -2339,14 +2365,14 @@ def capture_model(self) -> None: def _capture_cudagraphs(self, compilation_cases: list[int], cudagraph_runtime_style: CUDAGraphRuntimeStyle, - is_pure_decode: bool): + uniform_batch: bool): # Only rank 0 should print progress bar during capture if is_global_first_rank(): compilation_cases = tqdm( compilation_cases, disable=not self.load_config.use_tqdm_on_load, desc="Capturing CUDA graphs ({})".format( - "pure decode" if is_pure_decode else "mix prefill-decode")) + "decode" if uniform_batch else "mix prefill-decode")) # We skip EPLB here since we don't want to record dummy metrics for num_tokens in compilation_cases: for _ in range(self.compilation_config.cudagraph_num_of_warmups): @@ -2361,11 +2387,11 @@ def _capture_cudagraphs(self, compilation_cases: list[int], num_tokens, cudagraph_runtime_style=CUDAGraphRuntimeStyle.NONE, force_attention=force_attention, - is_pure_decode=is_pure_decode, + uniform_batch=uniform_batch, skip_eplb=True) self._dummy_run(num_tokens, cudagraph_runtime_style=cudagraph_runtime_style, - is_pure_decode=is_pure_decode, + uniform_batch=uniform_batch, skip_eplb=True) def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None: @@ -2427,11 +2453,10 @@ def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None: if attn_cg == AttentionCGSupport.NEVER: raise ValueError( - f"Full CUDAGraph not supported for " - f"{attn_backend_i.__name__}. 
Set " - f"CompilationConfig.cudagraph_mode to `NONE` " - f"or `PIECEWISE`, or use a different" - f" attention backend.") + f"Full CUDAGraph for {attn_backend_i.__name__} is " + f"no supported. Set CompilationConfig.cudagraph_mode " + f"to `NONE` or `PIECEWISE`, or use a different " + f"attention backend.") if len(self.compilation_config.splitting_ops) == 0: assert attn_cg in [ @@ -2444,40 +2469,40 @@ def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None: f"or use a different attention backend.") # check if the attention backends compatible with - # CompilationConfig.separate_attention_routine + # CompilationConfig.cudagraph_separate_routine if attn_cg == AttentionCGSupport.ALWAYS_UNIFIED and \ - self.compilation_config.separate_attention_routine: + self.compilation_config.cudagraph_separate_routine: logger.warning_once( f"Full CUDAGraph support for {attn_backend_i.__name__}" f" is {AttentionCGSupport.ALWAYS_UNIFIED}, which expect" - f"CompilationConfig.separate_attention_rountine as " + f"CompilationConfig.cudagraph_separate_routine as " f"False. Set it to False now.") - self.compilation_config.separate_attention_routine = \ + self.compilation_config.cudagraph_separate_routine = \ False if attn_cg == AttentionCGSupport.PURE_DECODE_ONLY and \ - not self.compilation_config.separate_attention_routine: + not self.compilation_config.cudagraph_separate_routine: logger.warning_once( f"Full CUDAGraph support for {attn_backend_i.__name__}" f" is {AttentionCGSupport.PURE_DECODE_ONLY}, which " - f"expect CompilationConfig.separate_attention_rountine" + f"expect CompilationConfig.cudagraph_separate_routine" f"as True. Set it to True now.") - self.compilation_config.separate_attention_routine = \ + self.compilation_config.cudagraph_separate_routine = \ True # when AttentionCGSupport.ALWAYS_SEPARATE, we don't change - # the separate_attention_routine flag, but should inform + # the cudagraph_separate_routine flag, but should inform # the user that this flag can be turned on to obtain # better performance. if attn_cg == AttentionCGSupport.ALWAYS_SEPARATE and \ - not self.compilation_config.separate_attention_routine: + not self.compilation_config.cudagraph_separate_routine: logger.warning_once( - f"Full CUDAGraph for {attn_backend_i.__name__} " - f"supports capturing separate attention routine " - f"for pure decode and mix prefill-decode batches. " - f"You can turn on CompilationConfig.separate_" - f"attention_routine to obtain better performance.") + f"{attn_backend_i.__name__} generally performs better " + f"when capturing full cudagraph for mix prefill-" + f"decode batches and pure decode batches in separate. " + f"To enable this behavior turn on " + f"CompilationConfig.cudagraph_separate_routine.") # for attn_cg is pure decode only, and no piecewise compilation, # we skip capturing mix prefill-decode (general) batches. 
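
The warnings above encode a small reconciliation policy between the backend's declared cudagraph support level and the user-requested cudagraph_separate_routine flag: ALWAYS_UNIFIED forces the flag off, PURE_DECODE_ONLY forces it on, and ALWAYS_SEPARATE only suggests enabling it. A standalone sketch of that policy follows; the enum and helper are simplified stand-ins and not the AttentionCGSupport definitions from vLLM.

# Sketch of the support-level reconciliation policy described above.
# The enum and helper are illustrative stand-ins only.
from enum import Enum, auto


class CGSupport(Enum):
    NEVER = auto()
    PURE_DECODE_ONLY = auto()
    ALWAYS_UNIFIED = auto()
    ALWAYS_SEPARATE = auto()


def resolve_separate_routine(support: CGSupport, requested: bool) -> bool:
    if support == CGSupport.ALWAYS_UNIFIED:
        return False   # backend uses one routine for all batch shapes
    if support == CGSupport.PURE_DECODE_ONLY:
        return True    # decode-only graphs require the separate routine
    if support == CGSupport.ALWAYS_SEPARATE and not requested:
        print("hint: enabling the separate routine may perform better")
    return requested   # ALWAYS_SEPARATE / NEVER keep the user's choice


assert resolve_separate_routine(CGSupport.ALWAYS_UNIFIED, True) is False
assert resolve_separate_routine(CGSupport.PURE_DECODE_ONLY, False) is True
assert resolve_separate_routine(CGSupport.ALWAYS_SEPARATE, False) is False
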
@@ -2485,19 +2510,17 @@ def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None: if self.no_piecewise_compilation: logger.warning_once( f"Skipping capturing mixed prefill-decode batches, " - f"since backend {attn_backend_i.__name__} only " + f"since backend {attn_backend_i.__name__} " f"supports full cudagraph for pure decode only and " - f"vllm piecewise compilation is no enabled.") + f"vllm piecewise compilation is disabled.") self.capture_mixed_batches = False else: - assert all(op in self.compilation_config.splitting_ops - for op in ["vllm.unified_attention", - "vllm.unified_attention_with_output"]),\ - "Invalid splitting_ops for piecewise compilation" + assert self.compilation_config.is_attention_splitting,\ + "Invalid splitting_ops for piecewise compilation " "with cudagraph_mode `FULL` for backend " f"{attn_backend_i.__name__}, which support " - "cudagraph on pure decode only. Please include " - "attention ops in compilation_config.splitting_ops" + "cudagraph only for pure decode. Please include " + "attention ops in compilation_config.splitting_ops." self.attn_backends.append(attn_backend_i) self.attn_metadata_builders.append(attn_metadata_builder_i) From 2e1304cec8149083bdda78745ed8440c36fd7727 Mon Sep 17 00:00:00 2001 From: fhl2000 <63384265+fhl2000@users.noreply.github.com> Date: Thu, 17 Jul 2025 16:40:02 +0000 Subject: [PATCH 33/33] fix pre-commit Signed-off-by: fhl2000 <63384265+fhl2000@users.noreply.github.com> --- vllm/config.py | 8 +++---- vllm/platforms/cuda.py | 4 ++-- vllm/v1/attention/backends/flash_attn.py | 2 +- vllm/v1/attention/backends/flashinfer.py | 28 ++++++++++++++---------- vllm/v1/cudagraph_dispatcher.py | 5 ++--- vllm/v1/worker/gpu_model_runner.py | 20 ++++++++--------- 6 files changed, 36 insertions(+), 31 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index bc841dee913..006b6296acb 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -4434,8 +4434,9 @@ def set_splitting_ops_for_v1(self): assert self.cudagraph_mode != CUDAGraphMode.PIECEWISE, ( "Cannot use piecewise CUDAGraph without splitting ops.") self.splitting_ops = [] - self.is_attention_splitting = all(op in self.splitting_ops for op in [ - "vllm.unified_attention", "vllm.unified_attention_with_output"]) + self.is_attention_splitting = all( + op in self.splitting_ops for op in + ["vllm.unified_attention", "vllm.unified_attention_with_output"]) @config @@ -4722,7 +4723,7 @@ def __post_init__(self): # otherwise, keep the cudagraph_mode as is else: self.compilation_config.cudagraph_mode = CUDAGraphMode.NONE - + if self.compilation_config.full_cuda_graph is not None: logger.warning( "`full_cuda_graph` is deprecated and will be removed in the " @@ -4738,7 +4739,6 @@ def __post_init__(self): self.compilation_config.cudagraph_mode = CUDAGraphMode.FULL self.compilation_config.use_cudagraph = True # other cases, keep the cudagraph_mode as is - # For V0 or other cases, default to level 0 with no compilation if self.compilation_config.level is None: diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index e3edffb197e..23ac5f73f48 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -165,7 +165,7 @@ def check_and_update_config(cls, vllm_config: "VllmConfig") -> None: cache_config.block_size = 64 logger.info( "Forcing kv cache block size to 64 for FlashMLA backend.") - + use_cutlass_mla = (envs.VLLM_ATTENTION_BACKEND is not None \ and envs.VLLM_ATTENTION_BACKEND == "CUTLASS_MLA_VLLM_V1") if use_cutlass_mla and cache_config.block_size != 128: @@ -175,7 +175,7 @@ 
def check_and_update_config(cls, vllm_config: "VllmConfig") -> None: # lazy import to avoid circular import from vllm.config import CUDAGraphMode - + compilation_config = vllm_config.compilation_config if (envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput" and parallel_config.data_parallel_size > 1 diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index f46c8a3cbff..e56a0d93fea 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -186,7 +186,7 @@ def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig, self.compilation_config.cudagraph_mode == CUDAGraphMode.FULL) if self.use_full_cuda_graph and self.aot_schedule: - self.max_cudagraph_size=self.compilation_config.max_capture_size + self.max_cudagraph_size = self.compilation_config.max_capture_size if self.max_cudagraph_size > 992: # This condition derives from FA3's internal heuristic. diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index 24f7f6faf61..5e3d6540079 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -18,10 +18,11 @@ from vllm.config import CUDAGraphMode, VllmConfig from vllm.logger import init_logger from vllm.platforms import current_platform +from vllm.utils import cdiv from vllm.v1.attention.backends.flash_attn import use_cascade_attention -from vllm.v1.attention.backends.utils import (AttentionCGSupport, - AttentionMetadataBuilder, CommonAttentionMetadata, PerLayerParameters, - get_kv_cache_layout, get_per_layer_parameters, +from vllm.v1.attention.backends.utils import ( + AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata, + PerLayerParameters, get_kv_cache_layout, get_per_layer_parameters, infer_global_hyperparameters, reorder_batch_to_split_decodes_and_prefills, split_decodes_and_prefills) from vllm.v1.kv_cache_interface import AttentionSpec @@ -233,7 +234,11 @@ def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig, self._prefill_wrapper = None # Wrapper for prefill/append self._decode_wrapper = None # Wrapper for decode (general shape) - compilation_config = self.vllm_config.compilation_config + compilation_config = vllm_config.compilation_config + max_num_pages_per_req = cdiv(vllm_config.model_config.max_model_len, + self.kv_cache_spec.block_size) + max_num_reqs = vllm_config.scheduler_config.max_num_seqs + max_num_pages = max_num_reqs * max_num_pages_per_req self.enable_cuda_graph = ( compilation_config.cudagraph_mode == CUDAGraphMode.FULL) if self.enable_cuda_graph: @@ -242,7 +247,7 @@ def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig, self._decode_wrappers_cudagraph: dict[ int, BatchDecodeWithPagedKVCacheWrapper] = {} self._decode_cudagraph_max_bs = min( - runner.max_num_reqs, compilation_config.max_capture_size) + max_num_reqs, compilation_config.max_capture_size) self._cascade_wrapper = None # Wrapper for cascade attention @@ -254,16 +259,16 @@ def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig, self.kv_cache_spec = kv_cache_spec # Preparing persistent buffers - self.paged_kv_indptr = torch.zeros(self.runner.max_num_reqs + 1, + self.paged_kv_indptr = torch.zeros(max_num_reqs + 1, dtype=torch.int32, - device=self.runner.device) + device=self.device) self.paged_kv_indices = torch.zeros( - block_table.get_device_tensor().numel(), # max num pages possible + max_num_pages, # max num pages possible dtype=torch.int32, - 
device=self.runner.device) - self.paged_kv_last_page_len = torch.zeros(self.runner.max_num_reqs, + device=self.device) + self.paged_kv_last_page_len = torch.zeros(max_num_reqs, dtype=torch.int32, - device=self.runner.device) + device=self.device) def reorder_batch(self, input_batch: InputBatch, scheduler_output: SchedulerOutput) -> bool: @@ -450,6 +455,7 @@ def build(self, common_prefix_len: int, common_attn_metadata: CommonAttentionMetadata, fast_build: bool = False) -> FlashInferMetadata: + num_reqs = common_attn_metadata.num_reqs num_actual_tokens = common_attn_metadata.num_actual_tokens num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens =\ split_decodes_and_prefills(common_attn_metadata) diff --git a/vllm/v1/cudagraph_dispatcher.py b/vllm/v1/cudagraph_dispatcher.py index fa4a6484a62..37d4e8da50e 100644 --- a/vllm/v1/cudagraph_dispatcher.py +++ b/vllm/v1/cudagraph_dispatcher.py @@ -85,8 +85,7 @@ def maybe_initialize_cudagraph(self, create_mixed_batch_full_cg: bool): cudagraph_options=CUDAGraphOptions( usage_str="full/uniform")) }) - logger.debug( - "Full cudagraph for uniform batches initialized") + logger.debug("Full cudagraph for uniform batches initialized") def get_cudagraph_runtime_style( self, attn_cuda_graphs: bool) -> CUDAGraphRuntimeStyle: # noqa @@ -115,7 +114,7 @@ def get_cudagraph_runtime_style( return CUDAGraphRuntimeStyle.NONE def dispatch(self, cudagraph_runtime_style: CUDAGraphRuntimeStyle, - uniform_batch: bool) -> Any: + uniform_batch: bool) -> Any: assert self.model is not None, ("No model have been assigned" "to cudagraph dispatcher") # if no cudagraph candidates, just skip dispatching. diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 875d905ffe7..c8a322654c8 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -319,7 +319,7 @@ def __init__( self.capture_mixed_batches = self.cudagraph_mode != CUDAGraphMode.NONE self.no_piecewise_compilation = self.compilation_config.level != \ CompilationLevel.PIECEWISE or self.model_config.enforce_eager - + self.uniform_decode_query_len = 1 if not self.speculative_config else \ 1 + self.speculative_config.num_speculative_tokens @@ -589,7 +589,7 @@ def _prepare_inputs( scheduler_output: "SchedulerOutput", ) -> tuple[dict[str, Any], bool, torch.Tensor, Optional[SpecDecodeMetadata], - np.ndarray, Optional[CommonAttentionMetadata]]: + np.ndarray, Optional[CommonAttentionMetadata], int]: """ :return: tuple[ attn_metadata: layer-to-attention_metadata mapping, @@ -772,8 +772,7 @@ def _prepare_inputs( return (attn_metadata, attention_cuda_graphs, logits_indices, spec_decode_metadata, num_scheduled_tokens, - spec_decode_common_attn_metadata, - max_num_scheduled_tokens) + spec_decode_common_attn_metadata, max_num_scheduled_tokens) def _compute_cascade_attn_prefix_len( self, @@ -1295,8 +1294,8 @@ def execute_model( # Prepare the decoder inputs. 
(attn_metadata, attention_cuda_graphs, logits_indices, spec_decode_metadata, num_scheduled_tokens_np, - spec_decode_common_attn_metadata, max_query_len) = ( - self._prepare_inputs(scheduler_output)) + spec_decode_common_attn_metadata, + max_query_len) = (self._prepare_inputs(scheduler_output)) num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens if (self.cudagraph_mode != CUDAGraphMode.NONE and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]): @@ -1951,7 +1950,7 @@ def _dummy_run( # Padding for DP num_pad, num_tokens_across_dp = self.get_dp_padding(num_tokens) num_tokens += num_pad - + # If cudagraph_separate_routine is enabled when use full cudagraph, # we need to manually activate the correct routine of attention backend # for mixed prefill-decode batches and uniform decode batches @@ -1968,7 +1967,7 @@ def _dummy_run( # for GQA/MQA. max_query_len = self.uniform_decode_query_len if uniform_batch else \ num_tokens - + # Set num_scheduled_tokens based on num_tokens and max_num_seqs # for dummy run with LoRA so that the num_reqs collectively # has num_tokens in total. @@ -1987,7 +1986,7 @@ def _dummy_run( min_tokens_per_req = num_tokens // num_reqs num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs num_scheduled_tokens_list[-1] += num_tokens % num_reqs - + assert sum(num_scheduled_tokens_list) == num_tokens assert len(num_scheduled_tokens_list) == num_reqs num_scheduled_tokens = np.array(num_scheduled_tokens_list, @@ -2346,7 +2345,8 @@ def capture_model(self) -> None: max_num_tokens = self.scheduler_config.max_num_seqs * \ self.uniform_decode_query_len decode_cudagraph_batch_sizes = [ - x for x in self.cudagraph_batch_sizes if x <= max_num_tokens + x for x in self.cudagraph_batch_sizes + if x <= max_num_tokens ] compilation_cases_decode = list( reversed(decode_cudagraph_batch_sizes))
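
For reference, the arithmetic behind the uniform-decode capture sizes and the dummy-batch token split used above can be sketched as follows. All names and numbers here are illustrative, not the runner's actual attributes or configured capture sizes.

# Sketch of the capture-size filtering and the per-request token split
# for a uniform dummy batch. Numbers below are illustrative only.
def decode_capture_sizes(capture_sizes: list[int], max_num_seqs: int,
                         uniform_decode_query_len: int) -> list[int]:
    max_num_tokens = max_num_seqs * uniform_decode_query_len
    return [s for s in capture_sizes if s <= max_num_tokens]


def uniform_dummy_split(num_tokens: int, q: int) -> list[int]:
    # ceil(num_tokens / q) requests of q tokens each, with a shorter last
    # request absorbing any remainder (the padding case during capture).
    num_reqs = (num_tokens + q - 1) // q
    split = [q] * num_reqs
    if num_tokens % q:
        split[-1] = num_tokens % q
    return split


# max_num_seqs=128, speculative decode with 2 draft tokens -> q = 3.
print(decode_capture_sizes([1, 2, 4, 8, 256, 512], 128, 3))  # [1, 2, 4, 8, 256]
print(uniform_dummy_split(8, 3))    # [3, 3, 2]
print(uniform_dummy_split(256, 3))  # 86 requests, the last with 1 token
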