From 8af5f3b61b3fb3323013e1d0b409b08d760fae14 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Wed, 16 Jul 2025 23:02:04 -0400 Subject: [PATCH 1/4] host buffers Signed-off-by: Lucas Wilkinson Optimize V1 FlashInfer backend to use CPU host buffers - Replace GPU-to-CPU transfers with direct CPU tensor construction - Build planning tensors from existing CommonAttentionMetadata CPU buffers - Reduce from 6x to 1x .cpu() calls during FlashInfer planning - Fix test mocks to handle correct argument count - Maintain compatibility with GPUModelRunner and FlashInfer V1 backend Signed-off-by: Lucas Wilkinson dont transfer block table Signed-off-by: Lucas Wilkinson optimize Signed-off-by: Lucas Wilkinson --- tests/v1/attention/test_attention_backends.py | 2 +- vllm/v1/attention/backends/flashinfer.py | 162 ++++++++++-------- 2 files changed, 96 insertions(+), 68 deletions(-) diff --git a/tests/v1/attention/test_attention_backends.py b/tests/v1/attention/test_attention_backends.py index b4e0101a0d4b..60234b4add50 100644 --- a/tests/v1/attention/test_attention_backends.py +++ b/tests/v1/attention/test_attention_backends.py @@ -212,7 +212,7 @@ def run_attention_backend(backend: _Backend, kv_cache_spec: FullAttentionSpec, from vllm.v1.attention.backends.flashinfer import PerLayerParameters - def mock_get_per_layer_parameters(vllm_config): + def mock_get_per_layer_parameters(vllm_config, impl_cls): # Return mock parameters for a single layer head_size = vllm_config.model_config.get_head_size() return { diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index 1eb27d57acf0..4890a79eeb6d 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -7,17 +7,18 @@ from typing import TYPE_CHECKING, Any, Optional import torch + +import vllm.envs as envs from flashinfer import (BatchDecodeWithPagedKVCacheWrapper, BatchPrefillWithPagedKVCacheWrapper, MultiLevelCascadeAttentionWrapper) from flashinfer.decode import trtllm_batch_decode_with_kv_cache - -import vllm.envs as envs from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionType) from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.platforms import current_platform +from vllm.utils import cdiv from vllm.v1.attention.backends.flash_attn import use_cascade_attention from vllm.v1.attention.backends.utils import ( AttentionMetadataBuilder, CommonAttentionMetadata, PerLayerParameters, @@ -167,13 +168,13 @@ class FlashInferMetadata: # [0, 5, 8, 1, 6, 7, 3, 4] # paged_kv_indptr is used to index into paged_kv_indices: # [0, 3, 6, 8] - # The indptr of the paged kv cache, shape: [batch_size + 1] - paged_kv_indptr: torch.Tensor - # The page indices of the paged kv cache + # The indptr of the paged kv cache, shape: [batch_size + 1] (CPU for plan) + paged_kv_indptr_cpu: torch.Tensor + # The page indices of the paged kv cache (on device for plan) paged_kv_indices: torch.Tensor # The number of entries in the last page of each request in - # the paged kv cache, shape: [batch_size] - paged_kv_last_page_len: torch.Tensor + # the paged kv cache, shape: [batch_size] (CPU for plan) + paged_kv_last_page_len_cpu: torch.Tensor # The number of query/output heads num_qo_heads: int # The number of key/value heads @@ -201,17 +202,20 @@ class FlashInferMetadata: num_prefills: int num_prefill_tokens: int - # For cascade attention. + # For cascade attention (CPU for planning). 
use_cascade: bool - shared_qo_indptr: Optional[torch.Tensor] = None - shared_kv_page_indptr: Optional[torch.Tensor] = None - shared_kv_page_indices: Optional[torch.Tensor] = None - shared_kv_last_page_len: Optional[torch.Tensor] = None + shared_qo_indptr_cpu: Optional[torch.Tensor] = None + shared_kv_page_indptr_cpu: Optional[torch.Tensor] = None + shared_kv_page_indices_cpu: Optional[torch.Tensor] = None + shared_kv_last_page_len_cpu: Optional[torch.Tensor] = None prefill_wrapper: Optional[BatchPrefillWithPagedKVCacheWrapper] = None decode_wrapper: Optional[BatchDecodeWithPagedKVCacheWrapper] = None cascade_wrapper: Optional[MultiLevelCascadeAttentionWrapper] = None + # CPU version for FlashInfer planning + qo_indptr_cpu: Optional[torch.Tensor] = None + @property def query_start_loc(self): # The GPUModelRunner expects to be able to access this property. @@ -238,6 +242,12 @@ def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig, self.vllm_config = vllm_config self.cache_config = vllm_config.cache_config self.kv_cache_spec = kv_cache_spec + max_num_blocks_per_request = cdiv( + vllm_config.model_config.max_model_len, + self.kv_cache_spec.block_size) + self.block_table_arange = torch.arange(max_num_blocks_per_request, + dtype=torch.int32, + device=self.device) def reorder_batch(self, input_batch: InputBatch, scheduler_output: SchedulerOutput) -> bool: @@ -285,21 +295,31 @@ def _plan(self, num_prefills: int, num_decodes: int, if self.global_hyperparameters is None: self.global_hyperparameters = infer_global_hyperparameters( get_per_layer_parameters(self.vllm_config, FlashInferImpl)) + + # Ensure CPU tensors are not None + assert attn_metadata.qo_indptr_cpu is not None + assert attn_metadata.paged_kv_indptr_cpu is not None + assert attn_metadata.paged_kv_indices is not None + assert attn_metadata.paged_kv_last_page_len_cpu is not None + if attn_metadata.use_cascade: attn_metadata.cascade_wrapper = self._get_cascade_wrapper() attn_metadata.cascade_wrapper.plan( - [attn_metadata.shared_qo_indptr, attn_metadata.qo_indptr], [ - attn_metadata.shared_kv_page_indptr, - attn_metadata.paged_kv_indptr + attn_metadata.shared_qo_indptr_cpu, + attn_metadata.qo_indptr_cpu + ], + [ + attn_metadata.shared_kv_page_indptr_cpu, + attn_metadata.paged_kv_indptr_cpu ], [ - attn_metadata.shared_kv_page_indices, + attn_metadata.shared_kv_page_indices_cpu, attn_metadata.paged_kv_indices ], [ - attn_metadata.shared_kv_last_page_len, - attn_metadata.paged_kv_last_page_len + attn_metadata.shared_kv_last_page_len_cpu, + attn_metadata.paged_kv_last_page_len_cpu ], attn_metadata.num_qo_heads, attn_metadata.num_kv_heads, @@ -320,22 +340,22 @@ def _plan(self, num_prefills: int, num_decodes: int, # Decodes are first so prefills start after the last decode prefill_start = num_decodes attn_metadata.prefill_wrapper = self._get_prefill_wrapper() - assert attn_metadata.qo_indptr[prefill_start:].shape[ + assert attn_metadata.qo_indptr_cpu[prefill_start:].shape[ 0] == num_prefills + 1 - assert attn_metadata.paged_kv_indptr[prefill_start:].shape[ + assert attn_metadata.paged_kv_indptr_cpu[prefill_start:].shape[ 0] == num_prefills + 1 - assert attn_metadata.paged_kv_last_page_len[ + assert attn_metadata.paged_kv_last_page_len_cpu[ prefill_start:].shape[0] == num_prefills # Since prefill_wrapper.run() will be called with # query[num_decode_tokens:] we need to adjust the qo_indptr # to be relative to the start of the prefill queries. 
- qo_indptr = attn_metadata.qo_indptr[ - prefill_start:] - attn_metadata.qo_indptr[prefill_start] + qo_indptr_cpu = attn_metadata.qo_indptr_cpu[ + prefill_start:] - attn_metadata.qo_indptr_cpu[prefill_start] attn_metadata.prefill_wrapper.plan( - qo_indptr, - attn_metadata.paged_kv_indptr[prefill_start:], + qo_indptr_cpu, + attn_metadata.paged_kv_indptr_cpu[prefill_start:], attn_metadata.paged_kv_indices, - attn_metadata.paged_kv_last_page_len[prefill_start:], + attn_metadata.paged_kv_last_page_len_cpu[prefill_start:], attn_metadata.num_qo_heads, attn_metadata.num_kv_heads, attn_metadata.head_dim, @@ -356,9 +376,9 @@ def _plan(self, num_prefills: int, num_decodes: int, attn_metadata.kv_data_type, attn_metadata.num_qo_heads, attn_metadata.num_kv_heads, attn_metadata.head_dim): attn_metadata.decode_wrapper.plan( - attn_metadata.paged_kv_indptr[:num_decodes + 1], + attn_metadata.paged_kv_indptr_cpu[:num_decodes + 1], attn_metadata.paged_kv_indices, - attn_metadata.paged_kv_last_page_len[:num_decodes], + attn_metadata.paged_kv_last_page_len_cpu[:num_decodes], attn_metadata.num_qo_heads, attn_metadata.num_kv_heads, attn_metadata.head_dim, @@ -382,55 +402,62 @@ def build(self, split_decodes_and_prefills(common_attn_metadata) page_size = self.kv_cache_spec.block_size - device = self.device qo_indptr = common_attn_metadata.query_start_loc max_seq_len = common_attn_metadata.seq_lens_cpu.max() seq_lens = common_attn_metadata.seq_lens block_table_tensor = common_attn_metadata.block_table_tensor - block_table_bounds = (seq_lens + page_size - 1) // page_size + # Build CPU versions directly from seq_lens_cpu + seq_lens_cpu = common_attn_metadata.seq_lens_cpu + block_table_bounds_cpu = (seq_lens_cpu + page_size - 1) // page_size use_cascade = common_prefix_len > 0 if use_cascade: # Grab the blocks of the shared prefix from the first request. assert common_prefix_len % page_size == 0 num_common_kv_blocks = common_prefix_len // page_size - shared_qo_indptr = torch.tensor([0, num_actual_tokens], - dtype=torch.int32, - device=device) - shared_kv_page_indptr = torch.tensor([0, num_common_kv_blocks], - dtype=torch.int32, - device=device) - shared_kv_page_indices = block_table_tensor[ + + # Create CPU versions directly for cascade (no GPU versions needed) + shared_qo_indptr_cpu = torch.tensor([0, num_actual_tokens], + dtype=torch.int32, + device='cpu') + shared_kv_page_indptr_cpu = torch.tensor([0, num_common_kv_blocks], + dtype=torch.int32, + device='cpu') + shared_kv_page_indices_cpu = block_table_tensor[ 0, :num_common_kv_blocks] - shared_kv_last_page_len = torch.tensor([page_size], - dtype=torch.int32, - device=device) + shared_kv_last_page_len_cpu = torch.tensor([page_size], + dtype=torch.int32, + device='cpu') + # Remove the blocks of the shared prefix from all requests. 
block_table_tensor = block_table_tensor[:, num_common_kv_blocks:] - block_table_bounds -= num_common_kv_blocks + block_table_bounds_cpu -= num_common_kv_blocks else: - shared_qo_indptr = None - shared_kv_page_indptr = None - shared_kv_page_indices = None - shared_kv_last_page_len = None - - mask = (torch.arange(block_table_tensor.size(1), - dtype=block_table_tensor.dtype, - device=block_table_tensor.device).unsqueeze(0) + shared_qo_indptr_cpu = None + shared_kv_page_indptr_cpu = None + shared_kv_page_indices_cpu = None + shared_kv_last_page_len_cpu = None + + max_num_blocks = block_table_bounds_cpu.max() + block_table_bounds = block_table_bounds_cpu.to(self.device, + non_blocking=True) + mask = (self.block_table_arange[:max_num_blocks].unsqueeze(0) < block_table_bounds.unsqueeze(1)) - paged_kv_indices = block_table_tensor[mask] - - paged_kv_indptr = torch.cat([ - torch.zeros(1, - dtype=block_table_bounds.dtype, - device=block_table_bounds.device), - block_table_bounds.cumsum(dim=0, dtype=torch.int32) - ]) - - paged_kv_last_page_len = seq_lens % page_size - paged_kv_last_page_len = torch.where(paged_kv_last_page_len == 0, - page_size, paged_kv_last_page_len) + paged_kv_indices = block_table_tensor[:, :max_num_blocks][mask] + + # paged_kv_indptr_cpu: cumulative sum of block_table_bounds_cpu + paged_kv_indptr_cpu = torch.zeros(len(block_table_bounds_cpu) + 1, + dtype=torch.int32, + device='cpu') + paged_kv_indptr_cpu[1:] = block_table_bounds_cpu.cumsum( + dim=0, dtype=torch.int32) + + # paged_kv_last_page_len_cpu: from seq_lens_cpu + paged_kv_last_page_len_cpu = seq_lens_cpu % page_size + paged_kv_last_page_len_cpu = torch.where( + paged_kv_last_page_len_cpu == 0, page_size, + paged_kv_last_page_len_cpu) cache_dtype = self.cache_config.cache_dtype if cache_dtype.startswith("fp8"): kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer( @@ -440,9 +467,10 @@ def build(self, attn_metadata = FlashInferMetadata( num_actual_tokens=num_actual_tokens, qo_indptr=qo_indptr, - paged_kv_indptr=paged_kv_indptr, + qo_indptr_cpu=common_attn_metadata.query_start_loc_cpu, + paged_kv_indptr_cpu=paged_kv_indptr_cpu, paged_kv_indices=paged_kv_indices, - paged_kv_last_page_len=paged_kv_last_page_len, + paged_kv_last_page_len_cpu=paged_kv_last_page_len_cpu, num_qo_heads=self.vllm_config.model_config.get_num_attention_heads( self.vllm_config.parallel_config), num_kv_heads=self.kv_cache_spec.num_kv_heads, @@ -456,10 +484,10 @@ def build(self, num_prefills=num_prefills, num_prefill_tokens=num_prefill_tokens, use_cascade=use_cascade, - shared_qo_indptr=shared_qo_indptr, - shared_kv_page_indptr=shared_kv_page_indptr, - shared_kv_page_indices=shared_kv_page_indices, - shared_kv_last_page_len=shared_kv_last_page_len, + shared_qo_indptr_cpu=shared_qo_indptr_cpu, + shared_kv_page_indptr_cpu=shared_kv_page_indptr_cpu, + shared_kv_page_indices_cpu=shared_kv_page_indices_cpu, + shared_kv_last_page_len_cpu=shared_kv_last_page_len_cpu, max_seq_len=max_seq_len, seq_lens=seq_lens, block_table_tensor=block_table_tensor, From 810ab8c82b6de2434128599ea9501ca9ba8a523a Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Thu, 17 Jul 2025 23:46:40 -0400 Subject: [PATCH 2/4] reorder imports Signed-off-by: Lucas Wilkinson --- vllm/v1/attention/backends/flashinfer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index 4890a79eeb6d..8b5ab6b242f3 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ 
b/vllm/v1/attention/backends/flashinfer.py @@ -8,11 +8,12 @@ import torch -import vllm.envs as envs from flashinfer import (BatchDecodeWithPagedKVCacheWrapper, BatchPrefillWithPagedKVCacheWrapper, MultiLevelCascadeAttentionWrapper) from flashinfer.decode import trtllm_batch_decode_with_kv_cache + +import vllm.envs as envs from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionType) from vllm.config import VllmConfig From 2c215774e0f2705e3806f003dd121d76105479ff Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Thu, 17 Jul 2025 23:52:51 -0400 Subject: [PATCH 3/4] cleanup Signed-off-by: Lucas Wilkinson --- vllm/v1/attention/backends/flashinfer.py | 12 +----------- 1 file changed, 1 insertion(+), 11 deletions(-) diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index 8b5ab6b242f3..77b74c72e694 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -7,7 +7,6 @@ from typing import TYPE_CHECKING, Any, Optional import torch - from flashinfer import (BatchDecodeWithPagedKVCacheWrapper, BatchPrefillWithPagedKVCacheWrapper, MultiLevelCascadeAttentionWrapper) @@ -297,12 +296,6 @@ def _plan(self, num_prefills: int, num_decodes: int, self.global_hyperparameters = infer_global_hyperparameters( get_per_layer_parameters(self.vllm_config, FlashInferImpl)) - # Ensure CPU tensors are not None - assert attn_metadata.qo_indptr_cpu is not None - assert attn_metadata.paged_kv_indptr_cpu is not None - assert attn_metadata.paged_kv_indices is not None - assert attn_metadata.paged_kv_last_page_len_cpu is not None - if attn_metadata.use_cascade: attn_metadata.cascade_wrapper = self._get_cascade_wrapper() attn_metadata.cascade_wrapper.plan( @@ -406,10 +399,9 @@ def build(self, qo_indptr = common_attn_metadata.query_start_loc max_seq_len = common_attn_metadata.seq_lens_cpu.max() seq_lens = common_attn_metadata.seq_lens + seq_lens_cpu = common_attn_metadata.seq_lens_cpu block_table_tensor = common_attn_metadata.block_table_tensor - # Build CPU versions directly from seq_lens_cpu - seq_lens_cpu = common_attn_metadata.seq_lens_cpu block_table_bounds_cpu = (seq_lens_cpu + page_size - 1) // page_size use_cascade = common_prefix_len > 0 @@ -447,14 +439,12 @@ def build(self, < block_table_bounds.unsqueeze(1)) paged_kv_indices = block_table_tensor[:, :max_num_blocks][mask] - # paged_kv_indptr_cpu: cumulative sum of block_table_bounds_cpu paged_kv_indptr_cpu = torch.zeros(len(block_table_bounds_cpu) + 1, dtype=torch.int32, device='cpu') paged_kv_indptr_cpu[1:] = block_table_bounds_cpu.cumsum( dim=0, dtype=torch.int32) - # paged_kv_last_page_len_cpu: from seq_lens_cpu paged_kv_last_page_len_cpu = seq_lens_cpu % page_size paged_kv_last_page_len_cpu = torch.where( paged_kv_last_page_len_cpu == 0, page_size, From 155e954987748e50e38d1085762007a37f0fbb77 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Fri, 18 Jul 2025 00:06:09 -0400 Subject: [PATCH 4/4] cleanup Signed-off-by: Lucas Wilkinson --- vllm/v1/attention/backends/flashinfer.py | 12 +----------- 1 file changed, 1 insertion(+), 11 deletions(-) diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index 77b74c72e694..e931fa9a2153 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -159,7 +159,7 @@ class FlashInferMetadata: # (batch_size + 1,). The cumulative subquery lengths of the sequences in # the batch, used to index into subquery. 
E.g., if the subquery length # is [4, 6], it is [0, 4, 10]. - qo_indptr: torch.Tensor + qo_indptr_cpu: torch.Tensor # An example for paged_kv_indices, paged_kv_indptr: # request 1, page indices [0, 5, 8] # request 2, page indices [1, 6, 7] @@ -213,14 +213,6 @@ class FlashInferMetadata: decode_wrapper: Optional[BatchDecodeWithPagedKVCacheWrapper] = None cascade_wrapper: Optional[MultiLevelCascadeAttentionWrapper] = None - # CPU version for FlashInfer planning - qo_indptr_cpu: Optional[torch.Tensor] = None - - @property - def query_start_loc(self): - # The GPUModelRunner expects to be able to access this property. - return self.qo_indptr - def __post_init__(self): if self.head_dim is not None: FlashInferBackend.validate_head_size(self.head_dim) @@ -396,7 +388,6 @@ def build(self, split_decodes_and_prefills(common_attn_metadata) page_size = self.kv_cache_spec.block_size - qo_indptr = common_attn_metadata.query_start_loc max_seq_len = common_attn_metadata.seq_lens_cpu.max() seq_lens = common_attn_metadata.seq_lens seq_lens_cpu = common_attn_metadata.seq_lens_cpu @@ -457,7 +448,6 @@ def build(self, kv_cache_dtype = self.kv_cache_spec.dtype attn_metadata = FlashInferMetadata( num_actual_tokens=num_actual_tokens, - qo_indptr=qo_indptr, qo_indptr_cpu=common_attn_metadata.query_start_loc_cpu, paged_kv_indptr_cpu=paged_kv_indptr_cpu, paged_kv_indices=paged_kv_indices,
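
Editor's note (not part of the patch series): below is a minimal, self-contained sketch of the host-buffer planning idea the first commit describes, namely building FlashInfer plan() inputs (paged_kv_indptr, paged_kv_last_page_len) directly from the CPU copy of the sequence lengths that the model runner already maintains, so planning no longer needs device-to-host .cpu() transfers. The helper name build_planning_buffers and its signature are assumptions made purely for illustration; tensor names mirror the patch, and only torch is required.

import torch

def build_planning_buffers(seq_lens_cpu: torch.Tensor,
                           block_table_gpu: torch.Tensor,
                           block_table_arange: torch.Tensor,
                           page_size: int):
    """Hypothetical helper: build FlashInfer plan() inputs from host metadata.

    seq_lens_cpu:       [batch] int32 tensor on the host (already kept by the runner)
    block_table_gpu:    [batch, max_blocks] int32 tensor on the device
    block_table_arange: precomputed arange(max_blocks) on the same device
    """
    # Number of KV-cache pages used by each request, computed on the host.
    block_table_bounds_cpu = (seq_lens_cpu + page_size - 1) // page_size

    # indptr over pages, shape [batch + 1]; stays on the CPU for plan().
    paged_kv_indptr_cpu = torch.zeros(seq_lens_cpu.numel() + 1,
                                      dtype=torch.int32)
    paged_kv_indptr_cpu[1:] = block_table_bounds_cpu.cumsum(0,
                                                            dtype=torch.int32)

    # Entries in the last page of each request, also kept on the CPU.
    paged_kv_last_page_len_cpu = seq_lens_cpu % page_size
    paged_kv_last_page_len_cpu = torch.where(paged_kv_last_page_len_cpu == 0,
                                             page_size,
                                             paged_kv_last_page_len_cpu)

    # Only the flat page-index list is needed on the device; the per-request
    # bounds are copied host-to-device asynchronously instead of forcing a
    # device-to-host sync, and the precomputed arange avoids re-allocating a
    # mask index every step.
    max_num_blocks = int(block_table_bounds_cpu.max())
    bounds = block_table_bounds_cpu.to(block_table_gpu.device,
                                       non_blocking=True)
    mask = (block_table_arange[:max_num_blocks].unsqueeze(0)
            < bounds.unsqueeze(1))
    paged_kv_indices = block_table_gpu[:, :max_num_blocks][mask]

    return paged_kv_indptr_cpu, paged_kv_indices, paged_kv_last_page_len_cpu

The design point, as reflected in the patch, is that the CPU-side outputs can be handed to the prefill/decode wrapper plan() calls as host tensors, while only paged_kv_indices (which the kernels index on-device) remains a GPU tensor, reducing the per-step planning path from several blocking .cpu() copies to a single non-blocking host-to-device copy.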