diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py
index 5fe274f2c65b..9df44835561f 100755
--- a/vllm/v1/attention/backends/flash_attn.py
+++ b/vllm/v1/attention/backends/flash_attn.py
@@ -25,7 +25,8 @@
 from vllm.config import VllmConfig, get_layers_from_vllm_config
 from vllm.logger import init_logger
 from vllm.utils import cdiv
-from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
+from vllm.v1.attention.backends.utils import (AttentionCGSupport,
+                                              AttentionMetadataBuilder,
                                               CommonAttentionMetadata,
                                               get_kv_cache_layout)
 from vllm.v1.kv_cache_interface import AttentionSpec
@@ -144,7 +145,9 @@ def _get_sliding_window_configs(
 
 class FlashAttentionMetadataBuilder(
         AttentionMetadataBuilder[FlashAttentionMetadata]):
-    full_cudagraph_supported: ClassVar[bool] = get_flash_attn_version() == 3
+    attn_cudagraph_support: ClassVar[AttentionCGSupport] = \
+        AttentionCGSupport.NEVER if get_flash_attn_version() == 2 \
+        else AttentionCGSupport.ALWAYS
 
     def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
                  device: torch.device):
diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py
index 94d80d441d8c..bf273e652df8 100755
--- a/vllm/v1/attention/backends/flashinfer.py
+++ b/vllm/v1/attention/backends/flashinfer.py
@@ -4,13 +4,14 @@
 from __future__ import annotations
 
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, ClassVar, Optional, Union
 
 import torch
 from flashinfer import (BatchDecodeWithPagedKVCacheWrapper,
                         BatchPrefillWithPagedKVCacheWrapper,
                         MultiLevelCascadeAttentionWrapper)
-from flashinfer.decode import trtllm_batch_decode_with_kv_cache
+from flashinfer.decode import (_get_range_buf, get_seq_lens,
+                               trtllm_batch_decode_with_kv_cache)
 
 import vllm.envs as envs
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
@@ -18,11 +19,11 @@
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
-from vllm.utils import cdiv
+from vllm.utils import cdiv, is_pin_memory_available
 from vllm.v1.attention.backends.flash_attn import use_cascade_attention
 from vllm.v1.attention.backends.utils import (
-    AttentionMetadataBuilder, CommonAttentionMetadata, PerLayerParameters,
-    get_kv_cache_layout, get_per_layer_parameters,
+    AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata,
+    PerLayerParameters, get_kv_cache_layout, get_per_layer_parameters,
     infer_global_hyperparameters, reorder_batch_to_split_decodes_and_prefills,
     split_decodes_and_prefills)
 from vllm.v1.kv_cache_interface import AttentionSpec
@@ -219,25 +220,65 @@ def __post_init__(self):
 
 
 class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
+    attn_cudagraph_support: ClassVar[AttentionCGSupport] = \
+        AttentionCGSupport.PURE_DECODE_ONLY
 
     def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
                  device: torch.device):
         self.device = device
+        self.vllm_config = vllm_config
+        self.cache_config = vllm_config.cache_config
+        self.kv_cache_spec = kv_cache_spec
         self._workspace_buffer = None
         self._prefill_wrapper = None  # Wrapper for prefill/append
-        self._decode_wrapper = None  # Wrapper for decode
+        self._decode_wrapper = None  # Wrapper for decode (general shape)
+
+        self.compilation_config = vllm_config.compilation_config
+        max_num_pages_per_req = cdiv(vllm_config.model_config.max_model_len,
+                                     self.kv_cache_spec.block_size)
+        max_num_reqs = vllm_config.scheduler_config.max_num_seqs
+        max_num_pages = max_num_reqs * max_num_pages_per_req
+        self.enable_cuda_graph = self.compilation_config.full_cuda_graph
+        if self.enable_cuda_graph:
+            # For full cudagraph capture, one `decode_wrapper` for each batch
+            # size is needed for FlashInfer.
+            self._decode_wrappers_cudagraph: dict[
+                int, BatchDecodeWithPagedKVCacheWrapper] = {}
+            self._decode_cudagraph_max_bs = min(
+                max_num_reqs, self.compilation_config.max_capture_size)
+
         self._cascade_wrapper = None  # Wrapper for cascade attention
 
         # Global hyperparameters shared by all attention layers
         self.global_hyperparameters: Optional[PerLayerParameters] = None
 
-        self.vllm_config = vllm_config
-        self.cache_config = vllm_config.cache_config
-        self.kv_cache_spec = kv_cache_spec
-        max_num_blocks_per_request = cdiv(
-            vllm_config.model_config.max_model_len,
-            self.kv_cache_spec.block_size)
-        self.block_table_arange = torch.arange(max_num_blocks_per_request,
+        # Preparing persistent buffers (device-side)
+        self.paged_kv_indptr = torch.zeros(max_num_reqs + 1,
+                                           dtype=torch.int32,
+                                           device=self.device)
+        self.paged_kv_indices = torch.zeros(
+            max_num_pages,  # max num pages possible
+            dtype=torch.int32,
+            device=self.device)
+        self.paged_kv_last_page_len = torch.zeros(max_num_reqs,
+                                                  dtype=torch.int32,
+                                                  device=self.device)
+        # host-side buffer
+        pin_memory = is_pin_memory_available()
+        self.paged_kv_indptr_cpu = torch.zeros(max_num_reqs + 1,
+                                               dtype=torch.int32,
+                                               device="cpu",
+                                               pin_memory=pin_memory)
+        self.paged_kv_indices_cpu = torch.zeros(max_num_pages,
+                                                dtype=torch.int32,
+                                                device="cpu",
+                                                pin_memory=pin_memory)
+        self.paged_kv_last_page_len_cpu = torch.zeros(max_num_reqs,
+                                                      dtype=torch.int32,
+                                                      device="cpu",
+                                                      pin_memory=pin_memory)
+
+        self.block_table_arange = torch.arange(max_num_pages_per_req,
                                                dtype=torch.int32,
                                                device=self.device)
@@ -261,8 +302,16 @@ def _get_prefill_wrapper(self):
             self._get_workspace_buffer(), get_kv_cache_layout())
         return self._prefill_wrapper
 
-    def _get_decode_wrapper(self):
-        if self._decode_wrapper is None:
+    def _get_decode_wrapper(self,
+                            batch_size: int,
+                            use_cudagraph: bool = False):
+        if use_cudagraph:
+            decode_wrapper = self._decode_wrappers_cudagraph.get(
+                batch_size, None)
+        else:
+            decode_wrapper = self._decode_wrapper
+
+        if decode_wrapper is None:
             num_qo_heads = (
                 self.vllm_config.model_config.get_num_attention_heads(
                     self.vllm_config.parallel_config))
@@ -270,11 +319,32 @@ def _get_decode_wrapper(self):
                 self.vllm_config.parallel_config)
             use_tensor_cores = envs.VLLM_FLASHINFER_FORCE_TENSOR_CORES or (
                 num_qo_heads // num_kv_heads > 4)
-            self._decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
+
+            if use_cudagraph:
+                paged_kv_indptr = self.paged_kv_indptr[:batch_size + 1]
+                paged_kv_indices = self.paged_kv_indices
+                paged_kv_last_page_len = self.paged_kv_last_page_len[:
+                                                                     batch_size]
+            else:
+                paged_kv_indptr = None
+                paged_kv_indices = None
+                paged_kv_last_page_len = None
+            decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
                 self._get_workspace_buffer(),
                 get_kv_cache_layout(),
+                use_cuda_graph=use_cudagraph,
+                paged_kv_indptr_buffer=paged_kv_indptr,
+                paged_kv_indices_buffer=paged_kv_indices,
+                paged_kv_last_page_len_buffer=paged_kv_last_page_len,
                 use_tensor_cores=use_tensor_cores)
-        return self._decode_wrapper
+
+            # save the decode wrapper
+            if use_cudagraph:
+                self._decode_wrappers_cudagraph[batch_size] = decode_wrapper
+            else:
+                self._decode_wrapper = decode_wrapper
+
+        return decode_wrapper
 
     def _get_cascade_wrapper(self):
         if self._cascade_wrapper is None:
@@ -356,16 +426,44 @@ def _plan(self, num_prefills: int, num_decodes: int,
             )
 
         if num_decodes > 0:
-            attn_metadata.decode_wrapper = self._get_decode_wrapper()
+            pure_decode = num_prefills == 0
+            # Padding may be required for cudagraph replay.
+            use_cudagraph = (self.enable_cuda_graph and pure_decode and
+                             num_decodes <= self._decode_cudagraph_max_bs)
+            if use_cudagraph:
+                num_input_tokens = (
+                    self.vllm_config.pad_for_cudagraph(num_decodes))
+                # Carefully fill the padding region with reasonable values
+                # on the CPU side.
+                # Make sure paged_kv_indptr_cpu is non-decreasing.
+                self.paged_kv_indptr_cpu[1 + num_decodes:1 +
+                                         num_input_tokens].fill_(
+                                             attn_metadata.
+                                             paged_kv_indptr_cpu[-1])
+                # Fill the remaining paged_kv_last_page_len_cpu with 1.
+                # This is because flashinfer treats 0 as a full page
+                # instead of an empty one.
+                self.paged_kv_last_page_len_cpu[
+                    num_decodes:num_input_tokens].fill_(1)
+
+            else:
+                num_input_tokens = num_decodes
+
+            attn_metadata.decode_wrapper = self._get_decode_wrapper(
+                num_input_tokens, use_cudagraph)
             if not FlashInferBackend.use_trtllm_decode_attention(
                     num_decodes, attn_metadata.max_seq_len,
                     self.cache_config.cache_dtype, attn_metadata.num_qo_heads,
                     attn_metadata.num_kv_heads, attn_metadata.head_dim):
-                attn_metadata.decode_wrapper.plan(
-                    attn_metadata.paged_kv_indptr_cpu[:num_decodes + 1],
+                # Use the persistent buffers with the padded length instead
+                # of the chunked views stored in attn_metadata when using
+                # cudagraph.
+                fast_plan_decode(
+                    attn_metadata.decode_wrapper,
+                    self.paged_kv_indptr_cpu[:num_input_tokens + 1],
                     attn_metadata.paged_kv_indices,
-                    attn_metadata.paged_kv_last_page_len_cpu[:num_decodes],
+                    self.paged_kv_last_page_len_cpu[:num_input_tokens],
                     attn_metadata.num_qo_heads,
                     attn_metadata.num_kv_heads,
                     attn_metadata.head_dim,
@@ -384,6 +482,7 @@ def build(self,
                   common_prefix_len: int,
                   common_attn_metadata: CommonAttentionMetadata,
                   fast_build: bool = False) -> FlashInferMetadata:
+        num_reqs = common_attn_metadata.num_reqs
         num_actual_tokens = common_attn_metadata.num_actual_tokens
         num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens =\
             split_decodes_and_prefills(common_attn_metadata)
@@ -429,18 +528,26 @@ def build(self,
                                                 non_blocking=True)
         mask = (self.block_table_arange[:max_num_blocks].unsqueeze(0)
                 < block_table_bounds.unsqueeze(1))
-        paged_kv_indices = block_table_tensor[:, :max_num_blocks][mask]
-
-        paged_kv_indptr_cpu = torch.zeros(len(block_table_bounds_cpu) + 1,
-                                          dtype=torch.int32,
-                                          device='cpu')
-        paged_kv_indptr_cpu[1:] = block_table_bounds_cpu.cumsum(
-            dim=0, dtype=torch.int32)
+        # Write self.paged_kv_indices in place.
+        num_actual_pages = torch.sum(mask)
+        paged_kv_indices = self.paged_kv_indices[:num_actual_pages]
+        torch.masked_select(block_table_tensor[:, :max_num_blocks],
+                            mask,
+                            out=paged_kv_indices)
+
+        # Write self.paged_kv_indptr_cpu in place (index 0 is always 0).
+        torch.cumsum(block_table_bounds_cpu,
+                     dim=0,
+                     dtype=torch.int32,
+                     out=self.paged_kv_indptr_cpu[1:1 + num_reqs])
 
         paged_kv_last_page_len_cpu = seq_lens_cpu % page_size
-        paged_kv_last_page_len_cpu = torch.where(
-            paged_kv_last_page_len_cpu == 0, page_size,
-            paged_kv_last_page_len_cpu)
+        # Write self.paged_kv_last_page_len_cpu in place.
+        torch.where(paged_kv_last_page_len_cpu == 0,
+                    torch.tensor(page_size),
+                    paged_kv_last_page_len_cpu,
+                    out=self.paged_kv_last_page_len_cpu[:num_reqs])
+
         cache_dtype = self.cache_config.cache_dtype
         if cache_dtype.startswith("fp8"):
             kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
@@ -450,9 +557,10 @@ def build(self,
         attn_metadata = FlashInferMetadata(
             num_actual_tokens=num_actual_tokens,
             qo_indptr_cpu=common_attn_metadata.query_start_loc_cpu,
-            paged_kv_indptr_cpu=paged_kv_indptr_cpu,
+            paged_kv_indptr_cpu=self.paged_kv_indptr_cpu[:1 + num_reqs],
             paged_kv_indices=paged_kv_indices,
-            paged_kv_last_page_len_cpu=paged_kv_last_page_len_cpu,
+            paged_kv_last_page_len_cpu=self.
+            paged_kv_last_page_len_cpu[:num_reqs],
             num_qo_heads=self.vllm_config.model_config.get_num_attention_heads(
                 self.vllm_config.parallel_config),
             num_kv_heads=self.kv_cache_spec.num_kv_heads,
@@ -480,6 +588,26 @@ def build(self,
 
         return attn_metadata
 
+    def build_for_cudagraph_capture(
+            self, common_attn_metadata: CommonAttentionMetadata):
+        """
+        This method builds the metadata for full cudagraph capture.
+        Currently, only pure decode is supported for full cudagraphs
+        with FlashInfer.
+        """
+        m = common_attn_metadata
+
+        assert m.num_reqs == m.num_actual_tokens, \
+            "FlashInfer only supports decode-only full CUDAGraph capture. " \
+            "Make sure all cudagraph capture sizes <= max_num_seqs."
+
+        m.max_query_len = 1  # decode-only
+
+        return self.build(0, m)
+
+    def can_run_in_cudagraph(
+            self, common_attn_metadata: CommonAttentionMetadata) -> bool:
+        return common_attn_metadata.max_query_len == 1
+
     def use_cascade_attention(self, *args, **kwargs) -> bool:
         if self.kv_cache_spec.dtype != self.vllm_config.model_config.dtype:
             # TODO: The cascade wrapper currently does not support setting
@@ -690,3 +818,163 @@ def forward(
                 v_scale=layer._v_scale_float,
             ))
         return output_padded
+
+
+def fast_plan_decode(
+    self,  # decode wrapper
+    indptr_cpu: torch.Tensor,
+    indices: torch.Tensor,
+    last_page_len_cpu: torch.Tensor,
+    num_qo_heads: int,
+    num_kv_heads: int,
+    head_dim: int,
+    page_size: int,
+    pos_encoding_mode: str = "NONE",
+    window_left: int = -1,
+    logits_soft_cap: Optional[float] = None,
+    q_data_type: Optional[Union[str, torch.dtype]] = "float16",
+    kv_data_type: Optional[Union[str, torch.dtype]] = None,
+    data_type: Optional[Union[str, torch.dtype]] = None,
+    sm_scale: Optional[float] = None,
+    rope_scale: Optional[float] = None,
+    rope_theta: Optional[float] = None,
+    non_blocking: bool = True,
+) -> None:
+    """
+    A faster version of BatchDecodeWithPagedKVCacheWrapper::plan used for
+    cudagraph capture/replay; without cudagraphs it falls back to the
+    original plan, passing the host-side buffers so that only
+    host-to-device copies of the indptr and last_page_len buffers occur.
+    Modifications for cudagraph:
+    - only host-to-device copies of the indptr and last_page_len buffers.
+    - avoids the device-to-device copy of the indices buffer.
+
+    Parts of the code take inspiration from the original plan in the
+    FlashInfer repo and from the fast_decode_plan implementation for
+    FlashInfer in the SGLang repo.
+    """
+    # Warm up with the original plan on the first call, and always run the
+    # original plan for dynamic shapes. For fixed shapes (cudagraph), this
+    # warm-up generates the _cached_module for the decode wrapper.
+    if not self.is_cuda_graph_enabled or \
+            getattr(self, "vllm_first_call", True):
+        self.plan(
+            indptr_cpu,
+            indices,
+            last_page_len_cpu,
+            num_qo_heads,
+            num_kv_heads,
+            head_dim,
+            page_size,
+            pos_encoding_mode,
+            window_left,
+            logits_soft_cap,
+            q_data_type,
+            kv_data_type,
+            data_type,
+            sm_scale,
+            rope_scale,
+            rope_theta,
+            non_blocking,
+        )
+        self.vllm_first_call = False
+        return
+
+    assert self.is_cuda_graph_enabled, "Should be cudagraph only here"
+
+    batch_size = len(last_page_len_cpu)
+    if logits_soft_cap is None:
+        logits_soft_cap = 0.0
+
+    # Handle data types consistently
+    if data_type is not None:
+        if q_data_type is None:
+            q_data_type = data_type
+        if kv_data_type is None:
+            kv_data_type = data_type
+    elif q_data_type is None:
+        q_data_type = "float16"
+
+    if kv_data_type is None:
+        kv_data_type = q_data_type
+    q_data_type = getattr(torch, q_data_type) if isinstance(
+        q_data_type, str) else q_data_type
+    kv_data_type = getattr(torch, kv_data_type) if isinstance(
+        kv_data_type, str) else kv_data_type
+
+    if self.use_tensor_cores:
+        qo_indptr_host = _get_range_buf(batch_size + 1, "cpu")
+
+    if batch_size != self._fixed_batch_size:
+        raise ValueError(
+            "The batch size should be fixed in cudagraph mode, the runtime "
+            "batch size {} mismatches the batch size set during "
+            "initialization {}".format(batch_size, self._fixed_batch_size))
+    if len(indices) > len(self._paged_kv_indices_buf):
+        raise ValueError(
+            "The size of indices should be less than or equal to the "
+            "allocated buffer")
+
+    # host-to-device copy for the indptr buffer
+    self._paged_kv_indptr_buf.copy_(indptr_cpu, non_blocking=True)
+    # host-to-device copy for the last_page_len buffer
+    self._paged_kv_last_page_len_buf.copy_(last_page_len_cpu,
+                                           non_blocking=True)
+
+    indptr_host = indptr_cpu
+    last_page_len_host = last_page_len_cpu
+
+    if self.use_tensor_cores:
+        kv_lens_arr_host = get_seq_lens(indptr_host, last_page_len_host,
+                                        page_size)
+
+        try:
+            # Make sure we pass exactly 15 arguments for tensor core version
+            self._plan_info = self._cached_module.plan(
+                self._float_workspace_buffer,
+                self._int_workspace_buffer,
+                self._pin_memory_int_workspace_buffer,
+                qo_indptr_host,
+                indptr_host,
+                kv_lens_arr_host,
+                batch_size,  # total_num_rows
+                batch_size,
+                num_qo_heads,
+                num_kv_heads,
+                page_size,
+                self.is_cuda_graph_enabled,
+                head_dim,
+                head_dim,
+                False,  # causal
+            )
+        except Exception as e:
+            raise RuntimeError(f"Error in tensor core plan: {e}") from e
+    else:
+        try:
+            # Make sure we pass exactly 15 arguments for standard version
+            self._plan_info = self._cached_module.plan(
+                self._float_workspace_buffer,
+                self._int_workspace_buffer,
+                self._pin_memory_int_workspace_buffer,
+                indptr_host,
+                batch_size,
+                num_qo_heads,
+                num_kv_heads,
+                page_size,
+                self.is_cuda_graph_enabled,
+                window_left,
+                logits_soft_cap,
+                head_dim,
+                head_dim,
+                torch.empty(0, dtype=q_data_type),
+                torch.empty(0, dtype=kv_data_type),
+            )
+        except Exception as e:
+            raise RuntimeError(f"Error in standard plan: {e}") from e
+
+    self._pos_encoding_mode = pos_encoding_mode
+    self._window_left = window_left
+    self._logits_soft_cap = logits_soft_cap
+    self._sm_scale = sm_scale
+    self._rope_scale = rope_scale
+    self._rope_theta = rope_theta
diff --git a/vllm/v1/attention/backends/mla/flashmla.py b/vllm/v1/attention/backends/mla/flashmla.py
index d3e5300dbbd6..cedb7d14d049 100644
--- a/vllm/v1/attention/backends/mla/flashmla.py
+++ b/vllm/v1/attention/backends/mla/flashmla.py
@@ -18,6 +18,7 @@
                                                     MLACommonImpl,
                                                     MLACommonMetadata,
                                                     MLACommonMetadataBuilder)
+from vllm.v1.attention.backends.utils import AttentionCGSupport
 from vllm.v1.kv_cache_interface import AttentionSpec
 
 logger = init_logger(__name__)
@@ -54,7 +55,8 @@ class FlashMLAMetadata(MLACommonMetadata[FlashMLADecodeMetadata]):
 
 
 class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
-    full_cudagraph_supported: ClassVar[bool] = True  # Decode-only
+    attn_cudagraph_support: ClassVar[AttentionCGSupport] = \
+        AttentionCGSupport.PURE_DECODE_ONLY
 
     def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
                  device: torch.device):
diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py
index 834c23455835..7a12acb1c3c8 100644
--- a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py
+++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py
@@ -17,6 +17,7 @@
                                                     MLACommonImpl,
                                                     MLACommonMetadata,
                                                     MLACommonMetadataBuilder)
+from vllm.v1.attention.backends.utils import AttentionCGSupport
 from vllm.v1.kv_cache_interface import AttentionSpec
 
 # yapf: enable
@@ -64,7 +65,8 @@ class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]):
 
 
 class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
-    full_cudagraph_supported: ClassVar[bool] = True  # decode only
+    attn_cudagraph_support: ClassVar[AttentionCGSupport] = \
+        AttentionCGSupport.PURE_DECODE_ONLY
 
     def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
                  device: torch.device):
diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py
index 83471ca51b73..217d0438d18e 100644
--- a/vllm/v1/attention/backends/triton_attn.py
+++ b/vllm/v1/attention/backends/triton_attn.py
@@ -18,7 +18,8 @@
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
 from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
-from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
+from vllm.v1.attention.backends.utils import (AttentionCGSupport,
+                                              AttentionMetadataBuilder,
                                               CommonAttentionMetadata)
 from vllm.v1.kv_cache_interface import AttentionSpec
@@ -57,7 +58,8 @@ class TritonAttentionMetadata:
 
 class TritonAttentionMetadataBuilder(
         AttentionMetadataBuilder[TritonAttentionMetadata]):
-    full_cudagraph_supported: ClassVar[bool] = True
+    attn_cudagraph_support: ClassVar[AttentionCGSupport] = \
+        AttentionCGSupport.ALWAYS
 
     def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
                  device: torch.device):
diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py
index fc8649d587ee..6665cd305c2e 100644
--- a/vllm/v1/attention/backends/utils.py
+++ b/vllm/v1/attention/backends/utils.py
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import abc
+import enum
 import functools
 from abc import abstractmethod
 from dataclasses import dataclass
@@ -63,9 +64,24 @@ class CommonAttentionMetadata:
 
 M = TypeVar("M")
 
 
+class AttentionCGSupport(enum.Enum):
+    """Constants for the cudagraph support of the attention backend.
+    Cascade attention is not considered here, as it is currently
+    never supported with cudagraphs."""
+
+    NEVER = 0
+    """No cudagraph support."""
+    PURE_DECODE_ONLY = 1
+    """Cudagraph supported for pure-decode batches; mixed
+    prefill-decode batches must run without cudagraph."""
+    ALWAYS = 2
+    """Cudagraph always supported."""
+
+
 class AttentionMetadataBuilder(abc.ABC, Generic[M]):
     # Does this backend/builder support CUDA Graphs for attention.
-    full_cudagraph_supported: ClassVar[bool] = False
+    attn_cudagraph_support: ClassVar[AttentionCGSupport] = \
+        AttentionCGSupport.NEVER
 
     @abstractmethod
     def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index a5bf197ba161..9a9c962f7e96 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -47,7 +47,7 @@
                         is_pin_memory_available, round_up)
 from vllm.v1.attention.backends.mamba_attn import Mamba2AttentionBackend
 from vllm.v1.attention.backends.utils import (
-    AttentionMetadataBuilder, CommonAttentionMetadata,
+    AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata,
     make_local_attention_virtual_batches)
 from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
 from vllm.v1.kv_cache_interface import (AttentionSpec,
@@ -2527,12 +2527,23 @@ def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None:
                 self.device,
             )
 
-            if (self.full_cuda_graph
-                    and not attn_metadata_builder_i.full_cudagraph_supported):
-                raise ValueError(
-                    f"Full CUDAGraph not supported for "
-                    f"{attn_backend_i.__name__}. Turn off CompilationConfig."
-                    f"full_cuda_graph or use a different attention backend.")
+            if self.full_cuda_graph:
+                if attn_metadata_builder_i.attn_cudagraph_support == \
+                        AttentionCGSupport.NEVER:
+                    raise ValueError(
+                        f"Full CUDAGraph not supported for "
+                        f"{attn_backend_i.__name__}. Turn off "
+                        f"CompilationConfig.full_cuda_graph or use a "
+                        f"different attention backend.")
+                if attn_metadata_builder_i.attn_cudagraph_support == \
+                        AttentionCGSupport.PURE_DECODE_ONLY:
+                    # Limit the max cudagraph size to the max number of
+                    # sequences for pure-decode-only cudagraph backends,
+                    # whose max_query_len is 1.
+                    self.cudagraph_batch_sizes = [
+                        size for size in self.cudagraph_batch_sizes
+                        if size <= self.scheduler_config.max_num_seqs
+                    ]
 
             self.attn_backends.append(attn_backend_i)
             self.attn_metadata_builders.append(attn_metadata_builder_i)
diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py
index 522946351148..c41295f85b97 100644
--- a/vllm/v1/worker/gpu_worker.py
+++ b/vllm/v1/worker/gpu_worker.py
@@ -300,11 +300,16 @@ def compile_or_warm_up_model(self) -> None:
         if get_pp_group().is_last_rank:
             max_num_reqs = min(self.scheduler_config.max_num_seqs,
                                self.scheduler_config.max_num_batched_tokens)
+            # Activate building attn_metadata for this dummy run to avoid
+            # potential illegal memory access during full cudagraph replay.
+            attn_cudagraph = self.compilation_config.full_cuda_graph and\
+                not self.model_config.enforce_eager
             # We skip EPLB here since we don't want to record dummy metrics
             hidden_states, last_hidden_states = \
                 self.model_runner._dummy_run(
                     num_tokens=max_num_reqs,
+                    capture_attn_cudagraph=attn_cudagraph,
                     skip_eplb=True,
                 )
             if self.model_runner.is_pooling_model:
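
For reference, below is a minimal, self-contained Python sketch of the AttentionCGSupport gating that the diff introduces. The enum values mirror the ones added in vllm/v1/attention/backends/utils.py, but DummyDecodeOnlyBuilder, resolve_capture_sizes, and the literal capture sizes are hypothetical stand-ins for the real metadata-builder classes and the GPUModelRunner logic, not vLLM APIs.

import enum
from typing import ClassVar


class AttentionCGSupport(enum.Enum):
    """Mirrors the enum added in vllm/v1/attention/backends/utils.py."""
    NEVER = 0              # no full-cudagraph support
    PURE_DECODE_ONLY = 1   # full cudagraph only for pure-decode batches
    ALWAYS = 2             # full cudagraph for any batch


class DummyDecodeOnlyBuilder:
    """Hypothetical stand-in for a builder such as FlashInferMetadataBuilder."""
    attn_cudagraph_support: ClassVar[AttentionCGSupport] = \
        AttentionCGSupport.PURE_DECODE_ONLY


def resolve_capture_sizes(builder_cls, capture_sizes, max_num_seqs):
    """Sketch of the gating the diff adds to GPUModelRunner."""
    support = builder_cls.attn_cudagraph_support
    if support == AttentionCGSupport.NEVER:
        raise ValueError(
            f"Full CUDAGraph not supported for {builder_cls.__name__}. "
            "Turn off CompilationConfig.full_cuda_graph or use a "
            "different attention backend.")
    if support == AttentionCGSupport.PURE_DECODE_ONLY:
        # Decode-only capture runs one token per request, so capture sizes
        # above max_num_seqs can never be replayed and are dropped.
        capture_sizes = [s for s in capture_sizes if s <= max_num_seqs]
    return capture_sizes


if __name__ == "__main__":
    print(resolve_capture_sizes(DummyDecodeOnlyBuilder,
                                [1, 2, 4, 8, 16, 32, 64, 128],
                                max_num_seqs=64))
    # -> [1, 2, 4, 8, 16, 32, 64]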