
Commit ad021d5

host buffers
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
1 parent 0748f22 commit ad021d5
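
In short, as the diff below shows: the FlashInfer plan() inputs (qo_indptr, paged_kv_indptr, paged_kv_indices, paged_kv_last_page_len, and their cascade-attention counterparts) are now built and kept as host ("CPU") buffers, renamed with a _cpu suffix, so that the wrapper plan() calls consume host tensors rather than device tensors.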

1 file changed (+92, -71 lines)

vllm/v1/attention/backends/flashinfer.py

Lines changed: 92 additions & 71 deletions
@@ -7,12 +7,12 @@
 from typing import TYPE_CHECKING, Any, Optional
 
 import torch
+
+import vllm.envs as envs
 from flashinfer import (BatchDecodeWithPagedKVCacheWrapper,
                         BatchPrefillWithPagedKVCacheWrapper,
                         MultiLevelCascadeAttentionWrapper)
 from flashinfer.decode import trtllm_batch_decode_with_kv_cache
-
-import vllm.envs as envs
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionType)
 from vllm.config import VllmConfig

@@ -167,13 +167,13 @@ class FlashInferMetadata:
     # [0, 5, 8, 1, 6, 7, 3, 4]
     # paged_kv_indptr is used to index into paged_kv_indices:
     # [0, 3, 6, 8]
-    # The indptr of the paged kv cache, shape: [batch_size + 1]
-    paged_kv_indptr: torch.Tensor
-    # The page indices of the paged kv cache
-    paged_kv_indices: torch.Tensor
+    # The indptr of the paged kv cache, shape: [batch_size + 1] (CPU for plan)
+    paged_kv_indptr_cpu: torch.Tensor
+    # The page indices of the paged kv cache (CPU for plan)
+    paged_kv_indices_cpu: torch.Tensor
     # The number of entries in the last page of each request in
-    # the paged kv cache, shape: [batch_size]
-    paged_kv_last_page_len: torch.Tensor
+    # the paged kv cache, shape: [batch_size] (CPU for plan)
+    paged_kv_last_page_len_cpu: torch.Tensor
     # The number of query/output heads
     num_qo_heads: int
     # The number of key/value heads

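As an aside, here is a minimal standalone sketch (not part of the diff) of the CSR-style layout these fields describe, reusing the example values from the comment above:

import torch

# Example from the comment above: three requests whose pages are
# [0, 5, 8], [1, 6, 7], and [3, 4] respectively.
paged_kv_indices_cpu = torch.tensor([0, 5, 8, 1, 6, 7, 3, 4],
                                    dtype=torch.int32)
paged_kv_indptr_cpu = torch.tensor([0, 3, 6, 8], dtype=torch.int32)

# Pages of request i are indices[indptr[i]:indptr[i + 1]].
for i in range(paged_kv_indptr_cpu.numel() - 1):
    lo, hi = paged_kv_indptr_cpu[i], paged_kv_indptr_cpu[i + 1]
    print(f"request {i}: pages {paged_kv_indices_cpu[lo:hi].tolist()}")
# request 0: pages [0, 5, 8]
# request 1: pages [1, 6, 7]
# request 2: pages [3, 4]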
@@ -201,17 +201,20 @@ class FlashInferMetadata:
     num_prefills: int
     num_prefill_tokens: int
 
-    # For cascade attention.
+    # For cascade attention (CPU for planning).
     use_cascade: bool
-    shared_qo_indptr: Optional[torch.Tensor] = None
-    shared_kv_page_indptr: Optional[torch.Tensor] = None
-    shared_kv_page_indices: Optional[torch.Tensor] = None
-    shared_kv_last_page_len: Optional[torch.Tensor] = None
+    shared_qo_indptr_cpu: Optional[torch.Tensor] = None
+    shared_kv_page_indptr_cpu: Optional[torch.Tensor] = None
+    shared_kv_page_indices_cpu: Optional[torch.Tensor] = None
+    shared_kv_last_page_len_cpu: Optional[torch.Tensor] = None
 
     prefill_wrapper: Optional[BatchPrefillWithPagedKVCacheWrapper] = None
     decode_wrapper: Optional[BatchDecodeWithPagedKVCacheWrapper] = None
     cascade_wrapper: Optional[MultiLevelCascadeAttentionWrapper] = None
 
+    # CPU version for FlashInfer planning
+    qo_indptr_cpu: Optional[torch.Tensor] = None
+
     @property
     def query_start_loc(self):
         # The GPUModelRunner expects to be able to access this property.

@@ -285,21 +288,31 @@ def _plan(self, num_prefills: int, num_decodes: int,
         if self.global_hyperparameters is None:
             self.global_hyperparameters = infer_global_hyperparameters(
                 get_per_layer_parameters(self.vllm_config, FlashInferImpl))
+
+        # Ensure CPU tensors are not None
+        assert attn_metadata.qo_indptr_cpu is not None
+        assert attn_metadata.paged_kv_indptr_cpu is not None
+        assert attn_metadata.paged_kv_indices_cpu is not None
+        assert attn_metadata.paged_kv_last_page_len_cpu is not None
+
         if attn_metadata.use_cascade:
             attn_metadata.cascade_wrapper = self._get_cascade_wrapper()
             attn_metadata.cascade_wrapper.plan(
-                [attn_metadata.shared_qo_indptr, attn_metadata.qo_indptr],
                 [
-                    attn_metadata.shared_kv_page_indptr,
-                    attn_metadata.paged_kv_indptr
+                    attn_metadata.shared_qo_indptr_cpu,
+                    attn_metadata.qo_indptr_cpu
+                ],
+                [
+                    attn_metadata.shared_kv_page_indptr_cpu,
+                    attn_metadata.paged_kv_indptr_cpu
                 ],
                 [
-                    attn_metadata.shared_kv_page_indices,
-                    attn_metadata.paged_kv_indices
+                    attn_metadata.shared_kv_page_indices_cpu,
+                    attn_metadata.paged_kv_indices_cpu
                 ],
                 [
-                    attn_metadata.shared_kv_last_page_len,
-                    attn_metadata.paged_kv_last_page_len
+                    attn_metadata.shared_kv_last_page_len_cpu,
+                    attn_metadata.paged_kv_last_page_len_cpu
                 ],
                 attn_metadata.num_qo_heads,
                 attn_metadata.num_kv_heads,

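For orientation, a sketch (toy values, not from the commit) of how the [shared, per-request] pairs passed to cascade_wrapper.plan() line up: level 0 treats the whole batch as a single request attending to the shared-prefix pages, while level 1 carries each request's suffix layout:

import torch

num_actual_tokens = 12      # total tokens across a 3-request batch
num_common_kv_blocks = 2    # pages in the shared prefix

# Level 0: one virtual request spanning all 12 tokens and 2 pages.
shared_qo_indptr_cpu = torch.tensor([0, num_actual_tokens],
                                    dtype=torch.int32)
shared_kv_page_indptr_cpu = torch.tensor([0, num_common_kv_blocks],
                                         dtype=torch.int32)

# Level 1: per-request token and page offsets (suffix only).
qo_indptr_cpu = torch.tensor([0, 4, 8, 12], dtype=torch.int32)
paged_kv_indptr_cpu = torch.tensor([0, 1, 3, 4], dtype=torch.int32)

# plan() receives each argument as a [level_0, level_1] pair:
qo_indptr_arg = [shared_qo_indptr_cpu, qo_indptr_cpu]
kv_indptr_arg = [shared_kv_page_indptr_cpu, paged_kv_indptr_cpu]
assert qo_indptr_arg[0][-1] == qo_indptr_arg[1][-1]  # both cover 12 tokens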
@@ -320,22 +333,22 @@ def _plan(self, num_prefills: int, num_decodes: int,
             # Decodes are first so prefills start after the last decode
             prefill_start = num_decodes
             attn_metadata.prefill_wrapper = self._get_prefill_wrapper()
-            assert attn_metadata.qo_indptr[prefill_start:].shape[
+            assert attn_metadata.qo_indptr_cpu[prefill_start:].shape[
                 0] == num_prefills + 1
-            assert attn_metadata.paged_kv_indptr[prefill_start:].shape[
+            assert attn_metadata.paged_kv_indptr_cpu[prefill_start:].shape[
                 0] == num_prefills + 1
-            assert attn_metadata.paged_kv_last_page_len[
+            assert attn_metadata.paged_kv_last_page_len_cpu[
                 prefill_start:].shape[0] == num_prefills
             # Since prefill_wrapper.run() will be called with
             # query[num_decode_tokens:] we need to adjust the qo_indptr
             # to be relative to the start of the prefill queries.
-            qo_indptr = attn_metadata.qo_indptr[
-                prefill_start:] - attn_metadata.qo_indptr[prefill_start]
+            qo_indptr_cpu = attn_metadata.qo_indptr_cpu[
+                prefill_start:] - attn_metadata.qo_indptr_cpu[prefill_start]
             attn_metadata.prefill_wrapper.plan(
-                qo_indptr,
-                attn_metadata.paged_kv_indptr[prefill_start:],
-                attn_metadata.paged_kv_indices,
-                attn_metadata.paged_kv_last_page_len[prefill_start:],
+                qo_indptr_cpu,
+                attn_metadata.paged_kv_indptr_cpu[prefill_start:],
+                attn_metadata.paged_kv_indices_cpu,
+                attn_metadata.paged_kv_last_page_len_cpu[prefill_start:],
                 attn_metadata.num_qo_heads,
                 attn_metadata.num_kv_heads,
                 attn_metadata.head_dim,

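The qo_indptr rebasing above is just a subtraction; a tiny sketch with invented numbers (2 decode requests followed by 2 prefills):

import torch

# Decodes come first: requests 0 and 1 contribute 1 token each,
# then the two prefills contribute 5 tokens each.
qo_indptr_cpu = torch.tensor([0, 1, 2, 7, 12], dtype=torch.int32)
num_decodes = 2
prefill_start = num_decodes

# Offsets relative to the first prefill token, since the wrapper is
# run on query[num_decode_tokens:] only.
qo_indptr_prefill = (qo_indptr_cpu[prefill_start:]
                     - qo_indptr_cpu[prefill_start])
print(qo_indptr_prefill.tolist())  # [0, 5, 10]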
@@ -356,9 +369,9 @@ def _plan(self, num_prefills: int, num_decodes: int,
                     attn_metadata.kv_data_type, attn_metadata.num_qo_heads,
                     attn_metadata.num_kv_heads, attn_metadata.head_dim):
                 attn_metadata.decode_wrapper.plan(
-                    attn_metadata.paged_kv_indptr[:num_decodes + 1],
-                    attn_metadata.paged_kv_indices,
-                    attn_metadata.paged_kv_last_page_len[:num_decodes],
+                    attn_metadata.paged_kv_indptr_cpu[:num_decodes + 1],
+                    attn_metadata.paged_kv_indices_cpu,
+                    attn_metadata.paged_kv_last_page_len_cpu[:num_decodes],
                     attn_metadata.num_qo_heads,
                     attn_metadata.num_kv_heads,
                     attn_metadata.head_dim,

@@ -382,55 +395,62 @@ def build(self,
             split_decodes_and_prefills(common_attn_metadata)
 
         page_size = self.kv_cache_spec.block_size
-        device = self.device
         qo_indptr = common_attn_metadata.query_start_loc
         max_seq_len = common_attn_metadata.seq_lens_cpu.max()
         seq_lens = common_attn_metadata.seq_lens
         block_table_tensor = common_attn_metadata.block_table_tensor
 
-        block_table_bounds = (seq_lens + page_size - 1) // page_size
+        # Build CPU versions directly from seq_lens_cpu
+        seq_lens_cpu = common_attn_metadata.seq_lens_cpu
+        block_table_bounds_cpu = (seq_lens_cpu + page_size - 1) // page_size
 
         use_cascade = common_prefix_len > 0
         if use_cascade:
             # Grab the blocks of the shared prefix from the first request.
             assert common_prefix_len % page_size == 0
             num_common_kv_blocks = common_prefix_len // page_size
-            shared_qo_indptr = torch.tensor([0, num_actual_tokens],
-                                            dtype=torch.int32,
-                                            device=device)
-            shared_kv_page_indptr = torch.tensor([0, num_common_kv_blocks],
-                                                 dtype=torch.int32,
-                                                 device=device)
-            shared_kv_page_indices = block_table_tensor[
+
+            # Create CPU versions directly for cascade (no GPU versions needed)
+            shared_qo_indptr_cpu = torch.tensor([0, num_actual_tokens],
+                                                dtype=torch.int32,
+                                                device='cpu')
+            shared_kv_page_indptr_cpu = torch.tensor([0, num_common_kv_blocks],
+                                                     dtype=torch.int32,
+                                                     device='cpu')
+            shared_kv_page_indices_cpu = block_table_tensor[
                 0, :num_common_kv_blocks]
-            shared_kv_last_page_len = torch.tensor([page_size],
-                                                   dtype=torch.int32,
-                                                   device=device)
+            shared_kv_last_page_len_cpu = torch.tensor([page_size],
+                                                       dtype=torch.int32,
+                                                       device='cpu')
+
             # Remove the blocks of the shared prefix from all requests.
             block_table_tensor = block_table_tensor[:, num_common_kv_blocks:]
-            block_table_bounds -= num_common_kv_blocks
+            block_table_bounds_cpu -= num_common_kv_blocks
         else:
-            shared_qo_indptr = None
-            shared_kv_page_indptr = None
-            shared_kv_page_indices = None
-            shared_kv_last_page_len = None
-
-        mask = (torch.arange(block_table_tensor.size(1),
-                             dtype=block_table_tensor.dtype,
-                             device=block_table_tensor.device).unsqueeze(0)
-                < block_table_bounds.unsqueeze(1))
-        paged_kv_indices = block_table_tensor[mask]
-
-        paged_kv_indptr = torch.cat([
-            torch.zeros(1,
-                        dtype=block_table_bounds.dtype,
-                        device=block_table_bounds.device),
-            block_table_bounds.cumsum(dim=0, dtype=torch.int32)
+            shared_qo_indptr_cpu = None
+            shared_kv_page_indptr_cpu = None
+            shared_kv_page_indices_cpu = None
+            shared_kv_last_page_len_cpu = None
+
+        # Build CPU versions directly from CPU data
+        # paged_kv_indices_cpu: extract from block_table on CPU
+        mask_cpu = (torch.arange(block_table_tensor.size(1),
+                                 dtype=torch.int32,
+                                 device='cpu').unsqueeze(0)
+                    < block_table_bounds_cpu.unsqueeze(1))
+        paged_kv_indices_cpu = block_table_tensor.cpu()[mask_cpu]
+
+        # paged_kv_indptr_cpu: cumulative sum of block_table_bounds_cpu
+        paged_kv_indptr_cpu = torch.cat([
+            torch.zeros(1, dtype=torch.int32, device='cpu'),
+            block_table_bounds_cpu.cumsum(dim=0, dtype=torch.int32)
         ])
 
-        paged_kv_last_page_len = seq_lens % page_size
-        paged_kv_last_page_len = torch.where(paged_kv_last_page_len == 0,
-                                             page_size, paged_kv_last_page_len)
+        # paged_kv_last_page_len_cpu: from seq_lens_cpu
+        paged_kv_last_page_len_cpu = seq_lens_cpu % page_size
+        paged_kv_last_page_len_cpu = torch.where(
+            paged_kv_last_page_len_cpu == 0, page_size,
+            paged_kv_last_page_len_cpu)
         cache_dtype = self.cache_config.cache_dtype
         if cache_dtype.startswith("fp8"):
            kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(

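Taken together, the CPU-side construction above can be exercised standalone; a sketch with made-up sizes (3 requests, page_size=16, a right-padded block table):

import torch

page_size = 16
seq_lens_cpu = torch.tensor([40, 16, 70], dtype=torch.int32)
block_table_cpu = torch.tensor([
    [0, 5, 8, 0, 0],    # request 0 needs ceil(40/16) = 3 pages
    [1, 0, 0, 0, 0],    # request 1 needs 1 page
    [3, 4, 6, 7, 9],    # request 2 needs ceil(70/16) = 5 pages
], dtype=torch.int32)

block_table_bounds_cpu = (seq_lens_cpu + page_size - 1) // page_size

# Select only the valid (non-padding) page slots, row by row.
mask_cpu = (torch.arange(block_table_cpu.size(1),
                         dtype=torch.int32).unsqueeze(0)
            < block_table_bounds_cpu.unsqueeze(1))
paged_kv_indices_cpu = block_table_cpu[mask_cpu]

paged_kv_indptr_cpu = torch.cat([
    torch.zeros(1, dtype=torch.int32),
    block_table_bounds_cpu.cumsum(dim=0, dtype=torch.int32)
])

# Last page holds seq_len % page_size entries, or page_size if full.
paged_kv_last_page_len_cpu = seq_lens_cpu % page_size
paged_kv_last_page_len_cpu = torch.where(
    paged_kv_last_page_len_cpu == 0, page_size,
    paged_kv_last_page_len_cpu)

print(paged_kv_indices_cpu.tolist())        # [0, 5, 8, 1, 3, 4, 6, 7, 9]
print(paged_kv_indptr_cpu.tolist())         # [0, 3, 4, 9]
print(paged_kv_last_page_len_cpu.tolist())  # [8, 16, 6]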
@@ -440,9 +460,10 @@ def build(self,
         attn_metadata = FlashInferMetadata(
             num_actual_tokens=num_actual_tokens,
             qo_indptr=qo_indptr,
-            paged_kv_indptr=paged_kv_indptr,
-            paged_kv_indices=paged_kv_indices,
-            paged_kv_last_page_len=paged_kv_last_page_len,
+            qo_indptr_cpu=common_attn_metadata.query_start_loc_cpu,
+            paged_kv_indptr_cpu=paged_kv_indptr_cpu,
+            paged_kv_indices_cpu=paged_kv_indices_cpu,
+            paged_kv_last_page_len_cpu=paged_kv_last_page_len_cpu,
             num_qo_heads=self.vllm_config.model_config.get_num_attention_heads(
                 self.vllm_config.parallel_config),
             num_kv_heads=self.kv_cache_spec.num_kv_heads,

@@ -456,10 +477,10 @@ def build(self,
             num_prefills=num_prefills,
             num_prefill_tokens=num_prefill_tokens,
             use_cascade=use_cascade,
-            shared_qo_indptr=shared_qo_indptr,
-            shared_kv_page_indptr=shared_kv_page_indptr,
-            shared_kv_page_indices=shared_kv_page_indices,
-            shared_kv_last_page_len=shared_kv_last_page_len,
+            shared_qo_indptr_cpu=shared_qo_indptr_cpu,
+            shared_kv_page_indptr_cpu=shared_kv_page_indptr_cpu,
+            shared_kv_page_indices_cpu=shared_kv_page_indices_cpu,
+            shared_kv_last_page_len_cpu=shared_kv_last_page_len_cpu,
             max_seq_len=max_seq_len,
             seq_lens=seq_lens,
             block_table_tensor=block_table_tensor,
