[Attention] Optimize FlashInfer MetadataBuilder Build call #21137

Open · wants to merge 4 commits into base: main
2 changes: 1 addition & 1 deletion tests/v1/attention/test_attention_backends.py
@@ -212,7 +212,7 @@ def run_attention_backend(backend: _Backend, kv_cache_spec: FullAttentionSpec,

from vllm.v1.attention.backends.flashinfer import PerLayerParameters

def mock_get_per_layer_parameters(vllm_config):
def mock_get_per_layer_parameters(vllm_config, impl_cls):
# Return mock parameters for a single layer
head_size = vllm_config.model_config.get_head_size()
return {
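For orientation, a hedged sketch of the adapted test double: the helper is now resolved per attention implementation, so the mock accepts the extra `impl_cls` argument and can simply ignore it. The layer name and the `PerLayerParameters` field values below are illustrative assumptions, since the dict body is elided in this view.

```python
from vllm.v1.attention.backends.flashinfer import PerLayerParameters


def mock_get_per_layer_parameters(vllm_config, impl_cls):
    # impl_cls (e.g. FlashInferImpl) is required by the new signature;
    # a test double can accept it and ignore it.
    head_size = vllm_config.model_config.get_head_size()
    return {
        "model.layers.0.self_attn.attn": PerLayerParameters(  # layer name assumed
            window_left=-1,                                   # fields/values assumed
            logits_soft_cap=0.0,
            sm_scale=1.0 / head_size**0.5,
        ),
    }
```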
155 changes: 82 additions & 73 deletions vllm/v1/attention/backends/flashinfer.py
@@ -18,6 +18,7 @@
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import cdiv
from vllm.v1.attention.backends.flash_attn import use_cascade_attention
from vllm.v1.attention.backends.utils import (
AttentionMetadataBuilder, CommonAttentionMetadata, PerLayerParameters,
@@ -158,7 +159,7 @@ class FlashInferMetadata:
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
# the batch, used to index into subquery. E.g., if the subquery length
# is [4, 6], it is [0, 4, 10].
qo_indptr: torch.Tensor
qo_indptr_cpu: torch.Tensor
# An example for paged_kv_indices, paged_kv_indptr:
# request 1, page indices [0, 5, 8]
# request 2, page indices [1, 6, 7]
@@ -167,13 +168,13 @@
# [0, 5, 8, 1, 6, 7, 3, 4]
# paged_kv_indptr is used to index into paged_kv_indices:
# [0, 3, 6, 8]
# The indptr of the paged kv cache, shape: [batch_size + 1]
paged_kv_indptr: torch.Tensor
# The page indices of the paged kv cache
# The indptr of the paged kv cache, shape: [batch_size + 1] (CPU for plan)
paged_kv_indptr_cpu: torch.Tensor
# The page indices of the paged kv cache (on device for plan)
paged_kv_indices: torch.Tensor
# The number of entries in the last page of each request in
# the paged kv cache, shape: [batch_size]
paged_kv_last_page_len: torch.Tensor
# the paged kv cache, shape: [batch_size] (CPU for plan)
paged_kv_last_page_len_cpu: torch.Tensor
# The number of query/output heads
num_qo_heads: int
# The number of key/value heads
@@ -201,22 +202,17 @@ class FlashInferMetadata:
num_prefills: int
num_prefill_tokens: int

# For cascade attention.
# For cascade attention (CPU for planning).
use_cascade: bool
shared_qo_indptr: Optional[torch.Tensor] = None
shared_kv_page_indptr: Optional[torch.Tensor] = None
shared_kv_page_indices: Optional[torch.Tensor] = None
shared_kv_last_page_len: Optional[torch.Tensor] = None
shared_qo_indptr_cpu: Optional[torch.Tensor] = None
shared_kv_page_indptr_cpu: Optional[torch.Tensor] = None
shared_kv_page_indices_cpu: Optional[torch.Tensor] = None
shared_kv_last_page_len_cpu: Optional[torch.Tensor] = None

prefill_wrapper: Optional[BatchPrefillWithPagedKVCacheWrapper] = None
decode_wrapper: Optional[BatchDecodeWithPagedKVCacheWrapper] = None
cascade_wrapper: Optional[MultiLevelCascadeAttentionWrapper] = None

@property
def query_start_loc(self):
# The GPUModelRunner expects to be able to access this property.
return self.qo_indptr

def __post_init__(self):
if self.head_dim is not None:
FlashInferBackend.validate_head_size(self.head_dim)
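As a quick illustration of the renamed CPU-side field, `qo_indptr_cpu` is just the exclusive cumulative sum of per-request query lengths, matching the comment above (lengths [4, 6] give [0, 4, 10]). This snippet is illustrative, not code from the PR:

```python
import torch

query_lens = torch.tensor([4, 6], dtype=torch.int32)
qo_indptr_cpu = torch.zeros(len(query_lens) + 1, dtype=torch.int32)
qo_indptr_cpu[1:] = torch.cumsum(query_lens, dim=0)
# qo_indptr_cpu -> tensor([ 0,  4, 10], dtype=torch.int32)
```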
@@ -238,6 +234,12 @@ def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
self.vllm_config = vllm_config
self.cache_config = vllm_config.cache_config
self.kv_cache_spec = kv_cache_spec
max_num_blocks_per_request = cdiv(
vllm_config.model_config.max_model_len,
self.kv_cache_spec.block_size)
self.block_table_arange = torch.arange(max_num_blocks_per_request,
dtype=torch.int32,
device=self.device)

def reorder_batch(self, input_batch: InputBatch,
scheduler_output: SchedulerOutput) -> bool:
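The new `__init__` code above precomputes one index range long enough for any request's block table, so `build()` no longer has to allocate it every step. A small stand-alone sketch, with a local stand-in for `vllm.utils.cdiv` and assumed values (`max_model_len=8192`, `block_size=16`):

```python
import torch


def cdiv(a: int, b: int) -> int:
    # Ceiling division; stand-in for vllm.utils.cdiv.
    return -(-a // b)


max_num_blocks_per_request = cdiv(8192, 16)  # 512 blocks at most
block_table_arange = torch.arange(max_num_blocks_per_request, dtype=torch.int32)
# Later, build() slices this as block_table_arange[:max_num_blocks] to mask
# out padding columns of the block table, instead of rebuilding an arange.
```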
@@ -285,21 +287,25 @@ def _plan(self, num_prefills: int, num_decodes: int,
if self.global_hyperparameters is None:
self.global_hyperparameters = infer_global_hyperparameters(
get_per_layer_parameters(self.vllm_config, FlashInferImpl))

if attn_metadata.use_cascade:
attn_metadata.cascade_wrapper = self._get_cascade_wrapper()
attn_metadata.cascade_wrapper.plan(
[attn_metadata.shared_qo_indptr, attn_metadata.qo_indptr],
[
attn_metadata.shared_kv_page_indptr,
attn_metadata.paged_kv_indptr
attn_metadata.shared_qo_indptr_cpu,
attn_metadata.qo_indptr_cpu
],
[
attn_metadata.shared_kv_page_indptr_cpu,
attn_metadata.paged_kv_indptr_cpu
],
[
attn_metadata.shared_kv_page_indices,
attn_metadata.shared_kv_page_indices_cpu,
attn_metadata.paged_kv_indices
],
[
attn_metadata.shared_kv_last_page_len,
attn_metadata.paged_kv_last_page_len
attn_metadata.shared_kv_last_page_len_cpu,
attn_metadata.paged_kv_last_page_len_cpu
],
attn_metadata.num_qo_heads,
attn_metadata.num_kv_heads,
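The cascade `plan()` call above takes two-level lists: index 0 describes the shared prefix as a single entry covering all tokens, index 1 the per-request remainders. A toy, host-only sketch of the level-0 tensors, assuming 2 requests totalling 8 tokens that share a 2-page prefix (page ids 0 and 5, `page_size=16`); values are illustrative:

```python
import torch

num_actual_tokens = 8          # all query tokens attend to the shared prefix
num_common_kv_blocks = 2       # common_prefix_len // page_size

shared_qo_indptr_cpu = torch.tensor([0, num_actual_tokens], dtype=torch.int32)
shared_kv_page_indptr_cpu = torch.tensor([0, num_common_kv_blocks], dtype=torch.int32)
shared_kv_page_indices_cpu = torch.tensor([0, 5], dtype=torch.int32)  # from request 0
shared_kv_last_page_len_cpu = torch.tensor([16], dtype=torch.int32)   # full page

# Each argument of cascade_wrapper.plan() is then a two-element list,
# [shared_x_cpu, per_request_x], with the per-request tensors built in build().
```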
@@ -320,22 +326,22 @@ def _plan(self, num_prefills: int, num_decodes: int,
# Decodes are first so prefills start after the last decode
prefill_start = num_decodes
attn_metadata.prefill_wrapper = self._get_prefill_wrapper()
assert attn_metadata.qo_indptr[prefill_start:].shape[
assert attn_metadata.qo_indptr_cpu[prefill_start:].shape[
0] == num_prefills + 1
assert attn_metadata.paged_kv_indptr[prefill_start:].shape[
assert attn_metadata.paged_kv_indptr_cpu[prefill_start:].shape[
0] == num_prefills + 1
assert attn_metadata.paged_kv_last_page_len[
assert attn_metadata.paged_kv_last_page_len_cpu[
prefill_start:].shape[0] == num_prefills
# Since prefill_wrapper.run() will be called with
# query[num_decode_tokens:] we need to adjust the qo_indptr
# to be relative to the start of the prefill queries.
qo_indptr = attn_metadata.qo_indptr[
prefill_start:] - attn_metadata.qo_indptr[prefill_start]
qo_indptr_cpu = attn_metadata.qo_indptr_cpu[
prefill_start:] - attn_metadata.qo_indptr_cpu[prefill_start]
attn_metadata.prefill_wrapper.plan(
qo_indptr,
attn_metadata.paged_kv_indptr[prefill_start:],
qo_indptr_cpu,
attn_metadata.paged_kv_indptr_cpu[prefill_start:],
attn_metadata.paged_kv_indices,
attn_metadata.paged_kv_last_page_len[prefill_start:],
attn_metadata.paged_kv_last_page_len_cpu[prefill_start:],
attn_metadata.num_qo_heads,
attn_metadata.num_kv_heads,
attn_metadata.head_dim,
@@ -356,9 +362,9 @@
attn_metadata.kv_data_type, attn_metadata.num_qo_heads,
attn_metadata.num_kv_heads, attn_metadata.head_dim):
attn_metadata.decode_wrapper.plan(
attn_metadata.paged_kv_indptr[:num_decodes + 1],
attn_metadata.paged_kv_indptr_cpu[:num_decodes + 1],
attn_metadata.paged_kv_indices,
attn_metadata.paged_kv_last_page_len[:num_decodes],
attn_metadata.paged_kv_last_page_len_cpu[:num_decodes],
attn_metadata.num_qo_heads,
attn_metadata.num_kv_heads,
attn_metadata.head_dim,
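A concrete illustration of the re-basing above. With decodes ordered first, assume 2 decode requests (1 token each) followed by 2 prefills of 3 and 5 tokens; the prefill wrapper is run on `query[num_decode_tokens:]`, so its indptr must start at 0 (the numbers are made up for illustration):

```python
import torch

qo_indptr_cpu = torch.tensor([0, 1, 2, 5, 10], dtype=torch.int32)  # whole batch
num_decodes = 2
prefill_start = num_decodes

# Shift so offsets are relative to the first prefill token.
prefill_qo_indptr_cpu = (qo_indptr_cpu[prefill_start:]
                         - qo_indptr_cpu[prefill_start])
# prefill_qo_indptr_cpu -> tensor([0, 3, 8], dtype=torch.int32)
```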
Expand All @@ -382,55 +388,58 @@ def build(self,
split_decodes_and_prefills(common_attn_metadata)

page_size = self.kv_cache_spec.block_size
device = self.device
qo_indptr = common_attn_metadata.query_start_loc
max_seq_len = common_attn_metadata.seq_lens_cpu.max()
seq_lens = common_attn_metadata.seq_lens
seq_lens_cpu = common_attn_metadata.seq_lens_cpu
block_table_tensor = common_attn_metadata.block_table_tensor

block_table_bounds = (seq_lens + page_size - 1) // page_size
block_table_bounds_cpu = (seq_lens_cpu + page_size - 1) // page_size

use_cascade = common_prefix_len > 0
if use_cascade:
# Grab the blocks of the shared prefix from the first request.
assert common_prefix_len % page_size == 0
num_common_kv_blocks = common_prefix_len // page_size
shared_qo_indptr = torch.tensor([0, num_actual_tokens],
dtype=torch.int32,
device=device)
shared_kv_page_indptr = torch.tensor([0, num_common_kv_blocks],
dtype=torch.int32,
device=device)
shared_kv_page_indices = block_table_tensor[

# Create CPU versions directly for cascade (no GPU versions needed)
shared_qo_indptr_cpu = torch.tensor([0, num_actual_tokens],
dtype=torch.int32,
device='cpu')
shared_kv_page_indptr_cpu = torch.tensor([0, num_common_kv_blocks],
dtype=torch.int32,
device='cpu')
shared_kv_page_indices_cpu = block_table_tensor[
0, :num_common_kv_blocks]
shared_kv_last_page_len = torch.tensor([page_size],
dtype=torch.int32,
device=device)
shared_kv_last_page_len_cpu = torch.tensor([page_size],
dtype=torch.int32,
device='cpu')

# Remove the blocks of the shared prefix from all requests.
block_table_tensor = block_table_tensor[:, num_common_kv_blocks:]
block_table_bounds -= num_common_kv_blocks
block_table_bounds_cpu -= num_common_kv_blocks
else:
shared_qo_indptr = None
shared_kv_page_indptr = None
shared_kv_page_indices = None
shared_kv_last_page_len = None

mask = (torch.arange(block_table_tensor.size(1),
dtype=block_table_tensor.dtype,
device=block_table_tensor.device).unsqueeze(0)
shared_qo_indptr_cpu = None
shared_kv_page_indptr_cpu = None
shared_kv_page_indices_cpu = None
shared_kv_last_page_len_cpu = None

max_num_blocks = block_table_bounds_cpu.max()
block_table_bounds = block_table_bounds_cpu.to(self.device,
non_blocking=True)
mask = (self.block_table_arange[:max_num_blocks].unsqueeze(0)
< block_table_bounds.unsqueeze(1))
paged_kv_indices = block_table_tensor[mask]

paged_kv_indptr = torch.cat([
torch.zeros(1,
dtype=block_table_bounds.dtype,
device=block_table_bounds.device),
block_table_bounds.cumsum(dim=0, dtype=torch.int32)
])

paged_kv_last_page_len = seq_lens % page_size
paged_kv_last_page_len = torch.where(paged_kv_last_page_len == 0,
page_size, paged_kv_last_page_len)
paged_kv_indices = block_table_tensor[:, :max_num_blocks][mask]

paged_kv_indptr_cpu = torch.zeros(len(block_table_bounds_cpu) + 1,
dtype=torch.int32,
device='cpu')
paged_kv_indptr_cpu[1:] = block_table_bounds_cpu.cumsum(
dim=0, dtype=torch.int32)

paged_kv_last_page_len_cpu = seq_lens_cpu % page_size
paged_kv_last_page_len_cpu = torch.where(
paged_kv_last_page_len_cpu == 0, page_size,
paged_kv_last_page_len_cpu)
cache_dtype = self.cache_config.cache_dtype
if cache_dtype.startswith("fp8"):
kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
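Pulling the planning-tensor computations of this `build()` hunk together on a toy input (the same page layout as in the metadata comments: pages [0, 5, 8], [1, 6, 7], [3, 4]); everything is kept on the host for simplicity, and the sequence lengths are assumptions:

```python
import torch

page_size = 16
block_table_tensor = torch.tensor([[0, 5, 8],
                                   [1, 6, 7],
                                   [3, 4, 0]], dtype=torch.int32)   # 0-padded
seq_lens_cpu = torch.tensor([48, 35, 32], dtype=torch.int32)

block_table_bounds_cpu = (seq_lens_cpu + page_size - 1) // page_size  # [3, 3, 2]
max_num_blocks = int(block_table_bounds_cpu.max())

# Mask out padding columns and flatten the used page ids (done on the device in the PR).
arange = torch.arange(max_num_blocks, dtype=torch.int32)
mask = arange.unsqueeze(0) < block_table_bounds_cpu.unsqueeze(1)
paged_kv_indices = block_table_tensor[:, :max_num_blocks][mask]
# -> tensor([0, 5, 8, 1, 6, 7, 3, 4], dtype=torch.int32)

# indptr over the flattened page list, built directly on the CPU.
paged_kv_indptr_cpu = torch.zeros(len(block_table_bounds_cpu) + 1, dtype=torch.int32)
paged_kv_indptr_cpu[1:] = block_table_bounds_cpu.cumsum(dim=0, dtype=torch.int32)
# -> tensor([0, 3, 6, 8], dtype=torch.int32)

# Entries used in each request's last page; a sequence ending exactly on a
# page boundary maps to page_size rather than 0.
paged_kv_last_page_len_cpu = seq_lens_cpu % page_size
paged_kv_last_page_len_cpu = torch.where(
    paged_kv_last_page_len_cpu == 0, page_size, paged_kv_last_page_len_cpu)
# -> tensor([16,  3, 16])
```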
@@ -439,10 +448,10 @@
kv_cache_dtype = self.kv_cache_spec.dtype
attn_metadata = FlashInferMetadata(
num_actual_tokens=num_actual_tokens,
qo_indptr=qo_indptr,
paged_kv_indptr=paged_kv_indptr,
qo_indptr_cpu=common_attn_metadata.query_start_loc_cpu,
paged_kv_indptr_cpu=paged_kv_indptr_cpu,
paged_kv_indices=paged_kv_indices,
paged_kv_last_page_len=paged_kv_last_page_len,
paged_kv_last_page_len_cpu=paged_kv_last_page_len_cpu,
num_qo_heads=self.vllm_config.model_config.get_num_attention_heads(
self.vllm_config.parallel_config),
num_kv_heads=self.kv_cache_spec.num_kv_heads,
@@ -456,10 +465,10 @@
num_prefills=num_prefills,
num_prefill_tokens=num_prefill_tokens,
use_cascade=use_cascade,
shared_qo_indptr=shared_qo_indptr,
shared_kv_page_indptr=shared_kv_page_indptr,
shared_kv_page_indices=shared_kv_page_indices,
shared_kv_last_page_len=shared_kv_last_page_len,
shared_qo_indptr_cpu=shared_qo_indptr_cpu,
shared_kv_page_indptr_cpu=shared_kv_page_indptr_cpu,
shared_kv_page_indices_cpu=shared_kv_page_indices_cpu,
shared_kv_last_page_len_cpu=shared_kv_last_page_len_cpu,
max_seq_len=max_seq_len,
seq_lens=seq_lens,
block_table_tensor=block_table_tensor,
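Why the CPU-resident planning tensors matter, as a hedged, stand-alone illustration rather than vLLM code: `plan()` needs host-side values, and pulling them from device tensors at plan time forces a blocking device-to-host copy every step, whereas keeping the inputs on the host (as this PR does) avoids that synchronization.

```python
import torch

seq_lens_cpu = torch.randint(1, 1024, (256,), dtype=torch.int32)
indptr_on_host = seq_lens_cpu.cumsum(0)                 # stays on the host, no sync

if torch.cuda.is_available():
    seq_lens_gpu = seq_lens_cpu.to("cuda", non_blocking=True)
    indptr_from_device = seq_lens_gpu.cumsum(0).cpu()   # blocking D2H copy each call
```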