diff --git a/tests/v1/attention/test_attention_backends.py b/tests/v1/attention/test_attention_backends.py index b4e0101a0d4b..60234b4add50 100644 --- a/tests/v1/attention/test_attention_backends.py +++ b/tests/v1/attention/test_attention_backends.py @@ -212,7 +212,7 @@ def run_attention_backend(backend: _Backend, kv_cache_spec: FullAttentionSpec, from vllm.v1.attention.backends.flashinfer import PerLayerParameters - def mock_get_per_layer_parameters(vllm_config): + def mock_get_per_layer_parameters(vllm_config, impl_cls): # Return mock parameters for a single layer head_size = vllm_config.model_config.get_head_size() return { diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index 1eb27d57acf0..e931fa9a2153 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -18,6 +18,7 @@ from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.platforms import current_platform +from vllm.utils import cdiv from vllm.v1.attention.backends.flash_attn import use_cascade_attention from vllm.v1.attention.backends.utils import ( AttentionMetadataBuilder, CommonAttentionMetadata, PerLayerParameters, @@ -158,7 +159,7 @@ class FlashInferMetadata: # (batch_size + 1,). The cumulative subquery lengths of the sequences in # the batch, used to index into subquery. E.g., if the subquery length # is [4, 6], it is [0, 4, 10]. - qo_indptr: torch.Tensor + qo_indptr_cpu: torch.Tensor # An example for paged_kv_indices, paged_kv_indptr: # request 1, page indices [0, 5, 8] # request 2, page indices [1, 6, 7] @@ -167,13 +168,13 @@ class FlashInferMetadata: # [0, 5, 8, 1, 6, 7, 3, 4] # paged_kv_indptr is used to index into paged_kv_indices: # [0, 3, 6, 8] - # The indptr of the paged kv cache, shape: [batch_size + 1] - paged_kv_indptr: torch.Tensor - # The page indices of the paged kv cache + # The indptr of the paged kv cache, shape: [batch_size + 1] (CPU for plan) + paged_kv_indptr_cpu: torch.Tensor + # The page indices of the paged kv cache (on device for plan) paged_kv_indices: torch.Tensor # The number of entries in the last page of each request in - # the paged kv cache, shape: [batch_size] - paged_kv_last_page_len: torch.Tensor + # the paged kv cache, shape: [batch_size] (CPU for plan) + paged_kv_last_page_len_cpu: torch.Tensor # The number of query/output heads num_qo_heads: int # The number of key/value heads @@ -201,22 +202,17 @@ class FlashInferMetadata: num_prefills: int num_prefill_tokens: int - # For cascade attention. + # For cascade attention (CPU for planning). use_cascade: bool - shared_qo_indptr: Optional[torch.Tensor] = None - shared_kv_page_indptr: Optional[torch.Tensor] = None - shared_kv_page_indices: Optional[torch.Tensor] = None - shared_kv_last_page_len: Optional[torch.Tensor] = None + shared_qo_indptr_cpu: Optional[torch.Tensor] = None + shared_kv_page_indptr_cpu: Optional[torch.Tensor] = None + shared_kv_page_indices_cpu: Optional[torch.Tensor] = None + shared_kv_last_page_len_cpu: Optional[torch.Tensor] = None prefill_wrapper: Optional[BatchPrefillWithPagedKVCacheWrapper] = None decode_wrapper: Optional[BatchDecodeWithPagedKVCacheWrapper] = None cascade_wrapper: Optional[MultiLevelCascadeAttentionWrapper] = None - @property - def query_start_loc(self): - # The GPUModelRunner expects to be able to access this property. - return self.qo_indptr - def __post_init__(self): if self.head_dim is not None: FlashInferBackend.validate_head_size(self.head_dim) @@ -238,6 +234,12 @@ def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig, self.vllm_config = vllm_config self.cache_config = vllm_config.cache_config self.kv_cache_spec = kv_cache_spec + max_num_blocks_per_request = cdiv( + vllm_config.model_config.max_model_len, + self.kv_cache_spec.block_size) + self.block_table_arange = torch.arange(max_num_blocks_per_request, + dtype=torch.int32, + device=self.device) def reorder_batch(self, input_batch: InputBatch, scheduler_output: SchedulerOutput) -> bool: @@ -285,21 +287,25 @@ def _plan(self, num_prefills: int, num_decodes: int, if self.global_hyperparameters is None: self.global_hyperparameters = infer_global_hyperparameters( get_per_layer_parameters(self.vllm_config, FlashInferImpl)) + if attn_metadata.use_cascade: attn_metadata.cascade_wrapper = self._get_cascade_wrapper() attn_metadata.cascade_wrapper.plan( - [attn_metadata.shared_qo_indptr, attn_metadata.qo_indptr], [ - attn_metadata.shared_kv_page_indptr, - attn_metadata.paged_kv_indptr + attn_metadata.shared_qo_indptr_cpu, + attn_metadata.qo_indptr_cpu + ], + [ + attn_metadata.shared_kv_page_indptr_cpu, + attn_metadata.paged_kv_indptr_cpu ], [ - attn_metadata.shared_kv_page_indices, + attn_metadata.shared_kv_page_indices_cpu, attn_metadata.paged_kv_indices ], [ - attn_metadata.shared_kv_last_page_len, - attn_metadata.paged_kv_last_page_len + attn_metadata.shared_kv_last_page_len_cpu, + attn_metadata.paged_kv_last_page_len_cpu ], attn_metadata.num_qo_heads, attn_metadata.num_kv_heads, @@ -320,22 +326,22 @@ def _plan(self, num_prefills: int, num_decodes: int, # Decodes are first so prefills start after the last decode prefill_start = num_decodes attn_metadata.prefill_wrapper = self._get_prefill_wrapper() - assert attn_metadata.qo_indptr[prefill_start:].shape[ + assert attn_metadata.qo_indptr_cpu[prefill_start:].shape[ 0] == num_prefills + 1 - assert attn_metadata.paged_kv_indptr[prefill_start:].shape[ + assert attn_metadata.paged_kv_indptr_cpu[prefill_start:].shape[ 0] == num_prefills + 1 - assert attn_metadata.paged_kv_last_page_len[ + assert attn_metadata.paged_kv_last_page_len_cpu[ prefill_start:].shape[0] == num_prefills # Since prefill_wrapper.run() will be called with # query[num_decode_tokens:] we need to adjust the qo_indptr # to be relative to the start of the prefill queries. - qo_indptr = attn_metadata.qo_indptr[ - prefill_start:] - attn_metadata.qo_indptr[prefill_start] + qo_indptr_cpu = attn_metadata.qo_indptr_cpu[ + prefill_start:] - attn_metadata.qo_indptr_cpu[prefill_start] attn_metadata.prefill_wrapper.plan( - qo_indptr, - attn_metadata.paged_kv_indptr[prefill_start:], + qo_indptr_cpu, + attn_metadata.paged_kv_indptr_cpu[prefill_start:], attn_metadata.paged_kv_indices, - attn_metadata.paged_kv_last_page_len[prefill_start:], + attn_metadata.paged_kv_last_page_len_cpu[prefill_start:], attn_metadata.num_qo_heads, attn_metadata.num_kv_heads, attn_metadata.head_dim, @@ -356,9 +362,9 @@ def _plan(self, num_prefills: int, num_decodes: int, attn_metadata.kv_data_type, attn_metadata.num_qo_heads, attn_metadata.num_kv_heads, attn_metadata.head_dim): attn_metadata.decode_wrapper.plan( - attn_metadata.paged_kv_indptr[:num_decodes + 1], + attn_metadata.paged_kv_indptr_cpu[:num_decodes + 1], attn_metadata.paged_kv_indices, - attn_metadata.paged_kv_last_page_len[:num_decodes], + attn_metadata.paged_kv_last_page_len_cpu[:num_decodes], attn_metadata.num_qo_heads, attn_metadata.num_kv_heads, attn_metadata.head_dim, @@ -382,55 +388,58 @@ def build(self, split_decodes_and_prefills(common_attn_metadata) page_size = self.kv_cache_spec.block_size - device = self.device - qo_indptr = common_attn_metadata.query_start_loc max_seq_len = common_attn_metadata.seq_lens_cpu.max() seq_lens = common_attn_metadata.seq_lens + seq_lens_cpu = common_attn_metadata.seq_lens_cpu block_table_tensor = common_attn_metadata.block_table_tensor - block_table_bounds = (seq_lens + page_size - 1) // page_size + block_table_bounds_cpu = (seq_lens_cpu + page_size - 1) // page_size use_cascade = common_prefix_len > 0 if use_cascade: # Grab the blocks of the shared prefix from the first request. assert common_prefix_len % page_size == 0 num_common_kv_blocks = common_prefix_len // page_size - shared_qo_indptr = torch.tensor([0, num_actual_tokens], - dtype=torch.int32, - device=device) - shared_kv_page_indptr = torch.tensor([0, num_common_kv_blocks], - dtype=torch.int32, - device=device) - shared_kv_page_indices = block_table_tensor[ + + # Create CPU versions directly for cascade (no GPU versions needed) + shared_qo_indptr_cpu = torch.tensor([0, num_actual_tokens], + dtype=torch.int32, + device='cpu') + shared_kv_page_indptr_cpu = torch.tensor([0, num_common_kv_blocks], + dtype=torch.int32, + device='cpu') + shared_kv_page_indices_cpu = block_table_tensor[ 0, :num_common_kv_blocks] - shared_kv_last_page_len = torch.tensor([page_size], - dtype=torch.int32, - device=device) + shared_kv_last_page_len_cpu = torch.tensor([page_size], + dtype=torch.int32, + device='cpu') + # Remove the blocks of the shared prefix from all requests. block_table_tensor = block_table_tensor[:, num_common_kv_blocks:] - block_table_bounds -= num_common_kv_blocks + block_table_bounds_cpu -= num_common_kv_blocks else: - shared_qo_indptr = None - shared_kv_page_indptr = None - shared_kv_page_indices = None - shared_kv_last_page_len = None - - mask = (torch.arange(block_table_tensor.size(1), - dtype=block_table_tensor.dtype, - device=block_table_tensor.device).unsqueeze(0) + shared_qo_indptr_cpu = None + shared_kv_page_indptr_cpu = None + shared_kv_page_indices_cpu = None + shared_kv_last_page_len_cpu = None + + max_num_blocks = block_table_bounds_cpu.max() + block_table_bounds = block_table_bounds_cpu.to(self.device, + non_blocking=True) + mask = (self.block_table_arange[:max_num_blocks].unsqueeze(0) < block_table_bounds.unsqueeze(1)) - paged_kv_indices = block_table_tensor[mask] - - paged_kv_indptr = torch.cat([ - torch.zeros(1, - dtype=block_table_bounds.dtype, - device=block_table_bounds.device), - block_table_bounds.cumsum(dim=0, dtype=torch.int32) - ]) - - paged_kv_last_page_len = seq_lens % page_size - paged_kv_last_page_len = torch.where(paged_kv_last_page_len == 0, - page_size, paged_kv_last_page_len) + paged_kv_indices = block_table_tensor[:, :max_num_blocks][mask] + + paged_kv_indptr_cpu = torch.zeros(len(block_table_bounds_cpu) + 1, + dtype=torch.int32, + device='cpu') + paged_kv_indptr_cpu[1:] = block_table_bounds_cpu.cumsum( + dim=0, dtype=torch.int32) + + paged_kv_last_page_len_cpu = seq_lens_cpu % page_size + paged_kv_last_page_len_cpu = torch.where( + paged_kv_last_page_len_cpu == 0, page_size, + paged_kv_last_page_len_cpu) cache_dtype = self.cache_config.cache_dtype if cache_dtype.startswith("fp8"): kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer( @@ -439,10 +448,10 @@ def build(self, kv_cache_dtype = self.kv_cache_spec.dtype attn_metadata = FlashInferMetadata( num_actual_tokens=num_actual_tokens, - qo_indptr=qo_indptr, - paged_kv_indptr=paged_kv_indptr, + qo_indptr_cpu=common_attn_metadata.query_start_loc_cpu, + paged_kv_indptr_cpu=paged_kv_indptr_cpu, paged_kv_indices=paged_kv_indices, - paged_kv_last_page_len=paged_kv_last_page_len, + paged_kv_last_page_len_cpu=paged_kv_last_page_len_cpu, num_qo_heads=self.vllm_config.model_config.get_num_attention_heads( self.vllm_config.parallel_config), num_kv_heads=self.kv_cache_spec.num_kv_heads, @@ -456,10 +465,10 @@ def build(self, num_prefills=num_prefills, num_prefill_tokens=num_prefill_tokens, use_cascade=use_cascade, - shared_qo_indptr=shared_qo_indptr, - shared_kv_page_indptr=shared_kv_page_indptr, - shared_kv_page_indices=shared_kv_page_indices, - shared_kv_last_page_len=shared_kv_last_page_len, + shared_qo_indptr_cpu=shared_qo_indptr_cpu, + shared_kv_page_indptr_cpu=shared_kv_page_indptr_cpu, + shared_kv_page_indices_cpu=shared_kv_page_indices_cpu, + shared_kv_last_page_len_cpu=shared_kv_last_page_len_cpu, max_seq_len=max_seq_len, seq_lens=seq_lens, block_table_tensor=block_table_tensor,