Skip to content

Commit 87ccacf

Browse files
optimize
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
1 parent 19b2d52 commit 87ccacf

File tree

1 file changed

+18
-11
lines changed

1 file changed

+18
-11
lines changed

vllm/v1/attention/backends/flashinfer.py

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from vllm.config import VllmConfig
1919
from vllm.logger import init_logger
2020
from vllm.platforms import current_platform
21+
from vllm.utils import cdiv
2122
from vllm.v1.attention.backends.flash_attn import use_cascade_attention
2223
from vllm.v1.attention.backends.utils import (
2324
AttentionMetadataBuilder, CommonAttentionMetadata, PerLayerParameters,
@@ -241,6 +242,12 @@ def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
241242
self.vllm_config = vllm_config
242243
self.cache_config = vllm_config.cache_config
243244
self.kv_cache_spec = kv_cache_spec
245+
max_num_blocks_per_request = cdiv(
246+
vllm_config.model_config.max_model_len,
247+
self.kv_cache_spec.block_size)
248+
self.block_table_arange = torch.arange(max_num_blocks_per_request,
249+
dtype=torch.int32,
250+
device=self.device)
244251

245252
def reorder_batch(self, input_batch: InputBatch,
246253
scheduler_output: SchedulerOutput) -> bool:
@@ -432,19 +439,19 @@ def build(self,
432439
shared_kv_page_indices_cpu = None
433440
shared_kv_last_page_len_cpu = None
434441

435-
# Build CPU versions directly from CPU data
436-
# paged_kv_indices_cpu: extract from block_table on CPU
437-
mask_cpu = (torch.arange(block_table_tensor.size(1),
438-
dtype=torch.int32,
439-
device='cpu').unsqueeze(0)
440-
< block_table_bounds_cpu.unsqueeze(1))
441-
paged_kv_indices = block_table_tensor[mask_cpu]
442+
max_num_blocks = block_table_bounds_cpu.max()
443+
block_table_bounds = block_table_bounds_cpu.to(self.device,
444+
non_blocking=True)
445+
mask = (self.block_table_arange[:max_num_blocks].unsqueeze(0)
446+
< block_table_bounds.unsqueeze(1))
447+
paged_kv_indices = block_table_tensor[:, :max_num_blocks][mask]
442448

443449
# paged_kv_indptr_cpu: cumulative sum of block_table_bounds_cpu
444-
paged_kv_indptr_cpu = torch.cat([
445-
torch.zeros(1, dtype=torch.int32, device='cpu'),
446-
block_table_bounds_cpu.cumsum(dim=0, dtype=torch.int32)
447-
])
450+
paged_kv_indptr_cpu = torch.zeros(len(block_table_bounds_cpu) + 1,
451+
dtype=torch.int32,
452+
device='cpu')
453+
paged_kv_indptr_cpu[1:] = block_table_bounds_cpu.cumsum(
454+
dim=0, dtype=torch.int32)
448455

449456
# paged_kv_last_page_len_cpu: from seq_lens_cpu
450457
paged_kv_last_page_len_cpu = seq_lens_cpu % page_size

0 commit comments

Comments
 (0)