
Commit 8af5f3b

host buffers

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>

Optimize V1 FlashInfer backend to use CPU host buffers:
- Replace GPU-to-CPU transfers with direct CPU tensor construction
- Build planning tensors from existing CommonAttentionMetadata CPU buffers
- Reduce from 6x to 1x .cpu() calls during FlashInfer planning
- Fix test mocks to handle the correct argument count
- Maintain compatibility with GPUModelRunner and the FlashInfer V1 backend

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>

don't transfer block table

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>

optimize

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
1 parent 8a8fc94 commit 8af5f3b
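The core of the change is to assemble FlashInfer's planning inputs on the host instead of computing them on the GPU and copying each one back. A minimal, self-contained sketch of that pattern (the helper build_plan_tensors_cpu is illustrative only, not vLLM API; the page-size arithmetic mirrors the diff below):

import torch

def build_plan_tensors_cpu(seq_lens_cpu: torch.Tensor, page_size: int):
    # Number of KV-cache pages used by each request.
    block_table_bounds_cpu = (seq_lens_cpu + page_size - 1) // page_size

    # indptr over pages, shape [batch_size + 1]; built directly on the CPU so
    # FlashInfer's plan() can consume it without a device-to-host transfer.
    paged_kv_indptr_cpu = torch.zeros(len(block_table_bounds_cpu) + 1,
                                      dtype=torch.int32)
    paged_kv_indptr_cpu[1:] = block_table_bounds_cpu.cumsum(
        dim=0, dtype=torch.int32)

    # Entries used in the last page of each request, shape [batch_size].
    last_page_len_cpu = seq_lens_cpu % page_size
    last_page_len_cpu = torch.where(last_page_len_cpu == 0, page_size,
                                    last_page_len_cpu)
    return paged_kv_indptr_cpu, last_page_len_cpu

# Example: three requests with 5, 16 and 33 tokens and 16-token pages.
indptr, last_len = build_plan_tensors_cpu(torch.tensor([5, 16, 33]), 16)
print(indptr)    # tensor([0, 1, 2, 5], dtype=torch.int32)
print(last_len)  # tensor([ 5, 16,  1])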

File tree

2 files changed: +96 additions, -68 deletions

tests/v1/attention/test_attention_backends.py
vllm/v1/attention/backends/flashinfer.py


tests/v1/attention/test_attention_backends.py

Lines changed: 1 addition & 1 deletion
@@ -212,7 +212,7 @@ def run_attention_backend(backend: _Backend, kv_cache_spec: FullAttentionSpec,
 
         from vllm.v1.attention.backends.flashinfer import PerLayerParameters
 
-        def mock_get_per_layer_parameters(vllm_config):
+        def mock_get_per_layer_parameters(vllm_config, impl_cls):
             # Return mock parameters for a single layer
             head_size = vllm_config.model_config.get_head_size()
             return {

vllm/v1/attention/backends/flashinfer.py

Lines changed: 95 additions & 67 deletions
@@ -7,17 +7,18 @@
 from typing import TYPE_CHECKING, Any, Optional
 
 import torch
+
+import vllm.envs as envs
 from flashinfer import (BatchDecodeWithPagedKVCacheWrapper,
                         BatchPrefillWithPagedKVCacheWrapper,
                         MultiLevelCascadeAttentionWrapper)
 from flashinfer.decode import trtllm_batch_decode_with_kv_cache
-
-import vllm.envs as envs
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionType)
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
+from vllm.utils import cdiv
 from vllm.v1.attention.backends.flash_attn import use_cascade_attention
 from vllm.v1.attention.backends.utils import (
     AttentionMetadataBuilder, CommonAttentionMetadata, PerLayerParameters,
@@ -167,13 +168,13 @@ class FlashInferMetadata:
     # [0, 5, 8, 1, 6, 7, 3, 4]
     # paged_kv_indptr is used to index into paged_kv_indices:
     # [0, 3, 6, 8]
-    # The indptr of the paged kv cache, shape: [batch_size + 1]
-    paged_kv_indptr: torch.Tensor
-    # The page indices of the paged kv cache
+    # The indptr of the paged kv cache, shape: [batch_size + 1] (CPU for plan)
+    paged_kv_indptr_cpu: torch.Tensor
+    # The page indices of the paged kv cache (on device for plan)
     paged_kv_indices: torch.Tensor
     # The number of entries in the last page of each request in
-    # the paged kv cache, shape: [batch_size]
-    paged_kv_last_page_len: torch.Tensor
+    # the paged kv cache, shape: [batch_size] (CPU for plan)
+    paged_kv_last_page_len_cpu: torch.Tensor
     # The number of query/output heads
     num_qo_heads: int
     # The number of key/value heads
@@ -201,17 +202,20 @@ class FlashInferMetadata:
     num_prefills: int
     num_prefill_tokens: int
 
-    # For cascade attention.
+    # For cascade attention (CPU for planning).
     use_cascade: bool
-    shared_qo_indptr: Optional[torch.Tensor] = None
-    shared_kv_page_indptr: Optional[torch.Tensor] = None
-    shared_kv_page_indices: Optional[torch.Tensor] = None
-    shared_kv_last_page_len: Optional[torch.Tensor] = None
+    shared_qo_indptr_cpu: Optional[torch.Tensor] = None
+    shared_kv_page_indptr_cpu: Optional[torch.Tensor] = None
+    shared_kv_page_indices_cpu: Optional[torch.Tensor] = None
+    shared_kv_last_page_len_cpu: Optional[torch.Tensor] = None
 
     prefill_wrapper: Optional[BatchPrefillWithPagedKVCacheWrapper] = None
     decode_wrapper: Optional[BatchDecodeWithPagedKVCacheWrapper] = None
     cascade_wrapper: Optional[MultiLevelCascadeAttentionWrapper] = None
 
+    # CPU version for FlashInfer planning
+    qo_indptr_cpu: Optional[torch.Tensor] = None
+
     @property
     def query_start_loc(self):
         # The GPUModelRunner expects to be able to access this property.
@@ -238,6 +242,12 @@ def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
         self.vllm_config = vllm_config
         self.cache_config = vllm_config.cache_config
         self.kv_cache_spec = kv_cache_spec
+        max_num_blocks_per_request = cdiv(
+            vllm_config.model_config.max_model_len,
+            self.kv_cache_spec.block_size)
+        self.block_table_arange = torch.arange(max_num_blocks_per_request,
+                                               dtype=torch.int32,
+                                               device=self.device)
 
     def reorder_batch(self, input_batch: InputBatch,
                       scheduler_output: SchedulerOutput) -> bool:
@@ -285,21 +295,31 @@ def _plan(self, num_prefills: int, num_decodes: int,
         if self.global_hyperparameters is None:
             self.global_hyperparameters = infer_global_hyperparameters(
                 get_per_layer_parameters(self.vllm_config, FlashInferImpl))
+
+        # Ensure CPU tensors are not None
+        assert attn_metadata.qo_indptr_cpu is not None
+        assert attn_metadata.paged_kv_indptr_cpu is not None
+        assert attn_metadata.paged_kv_indices is not None
+        assert attn_metadata.paged_kv_last_page_len_cpu is not None
+
         if attn_metadata.use_cascade:
             attn_metadata.cascade_wrapper = self._get_cascade_wrapper()
             attn_metadata.cascade_wrapper.plan(
-                [attn_metadata.shared_qo_indptr, attn_metadata.qo_indptr],
                 [
-                    attn_metadata.shared_kv_page_indptr,
-                    attn_metadata.paged_kv_indptr
+                    attn_metadata.shared_qo_indptr_cpu,
+                    attn_metadata.qo_indptr_cpu
+                ],
+                [
+                    attn_metadata.shared_kv_page_indptr_cpu,
+                    attn_metadata.paged_kv_indptr_cpu
                 ],
                 [
-                    attn_metadata.shared_kv_page_indices,
+                    attn_metadata.shared_kv_page_indices_cpu,
                     attn_metadata.paged_kv_indices
                 ],
                 [
-                    attn_metadata.shared_kv_last_page_len,
-                    attn_metadata.paged_kv_last_page_len
+                    attn_metadata.shared_kv_last_page_len_cpu,
+                    attn_metadata.paged_kv_last_page_len_cpu
                 ],
                 attn_metadata.num_qo_heads,
                 attn_metadata.num_kv_heads,
@@ -320,22 +340,22 @@ def _plan(self, num_prefills: int, num_decodes: int,
             # Decodes are first so prefills start after the last decode
             prefill_start = num_decodes
             attn_metadata.prefill_wrapper = self._get_prefill_wrapper()
-            assert attn_metadata.qo_indptr[prefill_start:].shape[
+            assert attn_metadata.qo_indptr_cpu[prefill_start:].shape[
                 0] == num_prefills + 1
-            assert attn_metadata.paged_kv_indptr[prefill_start:].shape[
+            assert attn_metadata.paged_kv_indptr_cpu[prefill_start:].shape[
                 0] == num_prefills + 1
-            assert attn_metadata.paged_kv_last_page_len[
+            assert attn_metadata.paged_kv_last_page_len_cpu[
                 prefill_start:].shape[0] == num_prefills
             # Since prefill_wrapper.run() will be called with
             # query[num_decode_tokens:] we need to adjust the qo_indptr
             # to be relative to the start of the prefill queries.
-            qo_indptr = attn_metadata.qo_indptr[
-                prefill_start:] - attn_metadata.qo_indptr[prefill_start]
+            qo_indptr_cpu = attn_metadata.qo_indptr_cpu[
+                prefill_start:] - attn_metadata.qo_indptr_cpu[prefill_start]
             attn_metadata.prefill_wrapper.plan(
-                qo_indptr,
-                attn_metadata.paged_kv_indptr[prefill_start:],
+                qo_indptr_cpu,
+                attn_metadata.paged_kv_indptr_cpu[prefill_start:],
                 attn_metadata.paged_kv_indices,
-                attn_metadata.paged_kv_last_page_len[prefill_start:],
+                attn_metadata.paged_kv_last_page_len_cpu[prefill_start:],
                 attn_metadata.num_qo_heads,
                 attn_metadata.num_kv_heads,
                 attn_metadata.head_dim,
@@ -356,9 +376,9 @@ def _plan(self, num_prefills: int, num_decodes: int,
                     attn_metadata.kv_data_type, attn_metadata.num_qo_heads,
                     attn_metadata.num_kv_heads, attn_metadata.head_dim):
                 attn_metadata.decode_wrapper.plan(
-                    attn_metadata.paged_kv_indptr[:num_decodes + 1],
+                    attn_metadata.paged_kv_indptr_cpu[:num_decodes + 1],
                     attn_metadata.paged_kv_indices,
-                    attn_metadata.paged_kv_last_page_len[:num_decodes],
+                    attn_metadata.paged_kv_last_page_len_cpu[:num_decodes],
                     attn_metadata.num_qo_heads,
                     attn_metadata.num_kv_heads,
                     attn_metadata.head_dim,
@@ -382,55 +402,62 @@ def build(self,
             split_decodes_and_prefills(common_attn_metadata)
 
         page_size = self.kv_cache_spec.block_size
-        device = self.device
         qo_indptr = common_attn_metadata.query_start_loc
         max_seq_len = common_attn_metadata.seq_lens_cpu.max()
        seq_lens = common_attn_metadata.seq_lens
         block_table_tensor = common_attn_metadata.block_table_tensor
 
-        block_table_bounds = (seq_lens + page_size - 1) // page_size
+        # Build CPU versions directly from seq_lens_cpu
+        seq_lens_cpu = common_attn_metadata.seq_lens_cpu
+        block_table_bounds_cpu = (seq_lens_cpu + page_size - 1) // page_size
 
         use_cascade = common_prefix_len > 0
         if use_cascade:
             # Grab the blocks of the shared prefix from the first request.
             assert common_prefix_len % page_size == 0
             num_common_kv_blocks = common_prefix_len // page_size
-            shared_qo_indptr = torch.tensor([0, num_actual_tokens],
-                                            dtype=torch.int32,
-                                            device=device)
-            shared_kv_page_indptr = torch.tensor([0, num_common_kv_blocks],
-                                                 dtype=torch.int32,
-                                                 device=device)
-            shared_kv_page_indices = block_table_tensor[
+
+            # Create CPU versions directly for cascade (no GPU versions needed)
+            shared_qo_indptr_cpu = torch.tensor([0, num_actual_tokens],
+                                                dtype=torch.int32,
+                                                device='cpu')
+            shared_kv_page_indptr_cpu = torch.tensor([0, num_common_kv_blocks],
+                                                     dtype=torch.int32,
+                                                     device='cpu')
+            shared_kv_page_indices_cpu = block_table_tensor[
                 0, :num_common_kv_blocks]
-            shared_kv_last_page_len = torch.tensor([page_size],
-                                                   dtype=torch.int32,
-                                                   device=device)
+            shared_kv_last_page_len_cpu = torch.tensor([page_size],
+                                                       dtype=torch.int32,
+                                                       device='cpu')
+
             # Remove the blocks of the shared prefix from all requests.
             block_table_tensor = block_table_tensor[:, num_common_kv_blocks:]
-            block_table_bounds -= num_common_kv_blocks
+            block_table_bounds_cpu -= num_common_kv_blocks
         else:
-            shared_qo_indptr = None
-            shared_kv_page_indptr = None
-            shared_kv_page_indices = None
-            shared_kv_last_page_len = None
-
-        mask = (torch.arange(block_table_tensor.size(1),
-                             dtype=block_table_tensor.dtype,
-                             device=block_table_tensor.device).unsqueeze(0)
+            shared_qo_indptr_cpu = None
+            shared_kv_page_indptr_cpu = None
+            shared_kv_page_indices_cpu = None
+            shared_kv_last_page_len_cpu = None
+
+        max_num_blocks = block_table_bounds_cpu.max()
+        block_table_bounds = block_table_bounds_cpu.to(self.device,
+                                                       non_blocking=True)
+        mask = (self.block_table_arange[:max_num_blocks].unsqueeze(0)
                 < block_table_bounds.unsqueeze(1))
-        paged_kv_indices = block_table_tensor[mask]
-
-        paged_kv_indptr = torch.cat([
-            torch.zeros(1,
-                        dtype=block_table_bounds.dtype,
-                        device=block_table_bounds.device),
-            block_table_bounds.cumsum(dim=0, dtype=torch.int32)
-        ])
-
-        paged_kv_last_page_len = seq_lens % page_size
-        paged_kv_last_page_len = torch.where(paged_kv_last_page_len == 0,
-                                             page_size, paged_kv_last_page_len)
+        paged_kv_indices = block_table_tensor[:, :max_num_blocks][mask]
+
+        # paged_kv_indptr_cpu: cumulative sum of block_table_bounds_cpu
+        paged_kv_indptr_cpu = torch.zeros(len(block_table_bounds_cpu) + 1,
+                                          dtype=torch.int32,
+                                          device='cpu')
+        paged_kv_indptr_cpu[1:] = block_table_bounds_cpu.cumsum(
+            dim=0, dtype=torch.int32)
+
+        # paged_kv_last_page_len_cpu: from seq_lens_cpu
+        paged_kv_last_page_len_cpu = seq_lens_cpu % page_size
+        paged_kv_last_page_len_cpu = torch.where(
+            paged_kv_last_page_len_cpu == 0, page_size,
+            paged_kv_last_page_len_cpu)
         cache_dtype = self.cache_config.cache_dtype
         if cache_dtype.startswith("fp8"):
             kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
@@ -440,9 +467,10 @@ def build(self,
         attn_metadata = FlashInferMetadata(
             num_actual_tokens=num_actual_tokens,
             qo_indptr=qo_indptr,
-            paged_kv_indptr=paged_kv_indptr,
+            qo_indptr_cpu=common_attn_metadata.query_start_loc_cpu,
+            paged_kv_indptr_cpu=paged_kv_indptr_cpu,
             paged_kv_indices=paged_kv_indices,
-            paged_kv_last_page_len=paged_kv_last_page_len,
+            paged_kv_last_page_len_cpu=paged_kv_last_page_len_cpu,
             num_qo_heads=self.vllm_config.model_config.get_num_attention_heads(
                 self.vllm_config.parallel_config),
             num_kv_heads=self.kv_cache_spec.num_kv_heads,
@@ -456,10 +484,10 @@ def build(self,
             num_prefills=num_prefills,
             num_prefill_tokens=num_prefill_tokens,
             use_cascade=use_cascade,
-            shared_qo_indptr=shared_qo_indptr,
-            shared_kv_page_indptr=shared_kv_page_indptr,
-            shared_kv_page_indices=shared_kv_page_indices,
-            shared_kv_last_page_len=shared_kv_last_page_len,
+            shared_qo_indptr_cpu=shared_qo_indptr_cpu,
+            shared_kv_page_indptr_cpu=shared_kv_page_indptr_cpu,
+            shared_kv_page_indices_cpu=shared_kv_page_indices_cpu,
+            shared_kv_last_page_len_cpu=shared_kv_last_page_len_cpu,
             max_seq_len=max_seq_len,
             seq_lens=seq_lens,
             block_table_tensor=block_table_tensor,
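The remaining planning input, paged_kv_indices, is gathered on the GPU with a mask built from the arange cached in __init__ above. A small standalone sketch of that gather (the helper gather_paged_kv_indices is hypothetical and runs on the CPU here; in the diff the bounds are moved to the device with a non-blocking copy and compared against self.block_table_arange):

import torch

def gather_paged_kv_indices(block_table: torch.Tensor,
                            block_table_bounds: torch.Tensor,
                            block_table_arange: torch.Tensor) -> torch.Tensor:
    # Only look at the block-table columns any request actually uses.
    max_num_blocks = int(block_table_bounds.max())
    # mask[i, j] is True when request i uses block slot j.
    mask = (block_table_arange[:max_num_blocks].unsqueeze(0)
            < block_table_bounds.unsqueeze(1))
    # Flatten the valid block IDs in request order.
    return block_table[:, :max_num_blocks][mask]

# Example: two requests using 2 and 3 blocks of a 4-wide block table.
block_table = torch.tensor([[10, 11, 0, 0],
                            [20, 21, 22, 0]], dtype=torch.int32)
bounds = torch.tensor([2, 3], dtype=torch.int32)
arange = torch.arange(4, dtype=torch.int32)
print(gather_paged_kv_indices(block_table, bounds, arange))
# tensor([10, 11, 20, 21, 22], dtype=torch.int32)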
