|
18 | 18 | from vllm.config import VllmConfig
|
19 | 19 | from vllm.logger import init_logger
|
20 | 20 | from vllm.platforms import current_platform
|
| 21 | +from vllm.utils import cdiv |
21 | 22 | from vllm.v1.attention.backends.flash_attn import use_cascade_attention
|
22 | 23 | from vllm.v1.attention.backends.utils import (
|
23 | 24 | AttentionMetadataBuilder, CommonAttentionMetadata, PerLayerParameters,
|
@@ -241,6 +242,12 @@ def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
|
241 | 242 | self.vllm_config = vllm_config
|
242 | 243 | self.cache_config = vllm_config.cache_config
|
243 | 244 | self.kv_cache_spec = kv_cache_spec
|
| 245 | + max_num_blocks_per_request = cdiv( |
| 246 | + vllm_config.model_config.max_model_len, |
| 247 | + self.kv_cache_spec.block_size) |
| 248 | + self.block_table_arange = torch.arange(max_num_blocks_per_request, |
| 249 | + dtype=torch.int32, |
| 250 | + device=self.device) |
244 | 251 |
|
245 | 252 | def reorder_batch(self, input_batch: InputBatch,
|
246 | 253 | scheduler_output: SchedulerOutput) -> bool:
|
@@ -432,19 +439,19 @@ def build(self,
|
432 | 439 | shared_kv_page_indices_cpu = None
|
433 | 440 | shared_kv_last_page_len_cpu = None
|
434 | 441 |
|
435 |
| - # Build CPU versions directly from CPU data |
436 |
| - # paged_kv_indices_cpu: extract from block_table on CPU |
437 |
| - mask_cpu = (torch.arange(block_table_tensor.size(1), |
438 |
| - dtype=torch.int32, |
439 |
| - device='cpu').unsqueeze(0) |
440 |
| - < block_table_bounds_cpu.unsqueeze(1)) |
441 |
| - paged_kv_indices = block_table_tensor[mask_cpu] |
| 442 | + max_num_blocks = block_table_bounds_cpu.max() |
| 443 | + block_table_bounds = block_table_bounds_cpu.to(self.device, |
| 444 | + non_blocking=True) |
| 445 | + mask = (self.block_table_arange[:max_num_blocks].unsqueeze(0) |
| 446 | + < block_table_bounds.unsqueeze(1)) |
| 447 | + paged_kv_indices = block_table_tensor[:, :max_num_blocks][mask] |
442 | 448 |
|
443 | 449 | # paged_kv_indptr_cpu: cumulative sum of block_table_bounds_cpu
|
444 |
| - paged_kv_indptr_cpu = torch.cat([ |
445 |
| - torch.zeros(1, dtype=torch.int32, device='cpu'), |
446 |
| - block_table_bounds_cpu.cumsum(dim=0, dtype=torch.int32) |
447 |
| - ]) |
| 450 | + paged_kv_indptr_cpu = torch.zeros(len(block_table_bounds_cpu) + 1, |
| 451 | + dtype=torch.int32, |
| 452 | + device='cpu') |
| 453 | + paged_kv_indptr_cpu[1:] = block_table_bounds_cpu.cumsum( |
| 454 | + dim=0, dtype=torch.int32) |
448 | 455 |
|
449 | 456 | # paged_kv_last_page_len_cpu: from seq_lens_cpu
|
450 | 457 | paged_kv_last_page_len_cpu = seq_lens_cpu % page_size
|
|
0 commit comments