@@ -4,6 +4,5 @@
 import copy
 import gc
 import time
-import weakref
 from contextlib import contextmanager
 from typing import TYPE_CHECKING, Any, Optional, Union
@@ -64,6 +63,5 @@
 from vllm.v1.spec_decode.ngram_proposer import NgramProposer
 from vllm.v1.spec_decode.utils import is_spec_decode_supported
 from vllm.v1.utils import bind_kv_cache
-from vllm.v1.worker.block_table import BlockTable
 from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
 from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
@@ -610,7 +608,7 @@ def _prepare_inputs(

         # OPTIMIZATION: Start copying the block table first.
         # This way, we can overlap the copy with the following CPU operations.
-        self.input_batch.block_table.commit(num_reqs)
+        self.input_batch.block_table.commit_block_table(num_reqs)

         # Get the number of scheduled tokens for each request.
         req_ids = self.input_batch.req_ids
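The single `commit` call is split into a block-table commit here and a separate slot-mapping commit later in `_prepare_inputs`. Below is a minimal sketch of what the two methods might do on a per-group block-table wrapper, assuming the same non-blocking host-to-device copies the old code performed inline; only the method names `commit_block_table` and `commit_slot_mapping` come from the diff, the class and attribute names are illustrative.

```python
# Hypothetical per-group block-table wrapper; only commit_block_table /
# commit_slot_mapping are taken from the diff, everything else is assumed.
import torch

class BlockTableGroupSketch:
    def __init__(self, max_num_reqs: int, max_num_blocks_per_req: int,
                 max_num_tokens: int, device: torch.device):
        # CPU-side staging buffers (pinned in a real implementation so the
        # non-blocking H2D copies can overlap with CPU work).
        self.block_table_cpu = torch.zeros(max_num_reqs, max_num_blocks_per_req,
                                           dtype=torch.int32)
        self.slot_mapping_cpu = torch.zeros(max_num_tokens, dtype=torch.int64)
        # Device-side buffers consumed by the attention kernels.
        self.block_table = torch.zeros_like(self.block_table_cpu, device=device)
        self.slot_mapping = torch.zeros_like(self.slot_mapping_cpu, device=device)

    def commit_block_table(self, num_reqs: int) -> None:
        # Kick off the block-table copy first so it overlaps with the CPU
        # operations that follow (see the OPTIMIZATION comment above).
        self.block_table[:num_reqs].copy_(self.block_table_cpu[:num_reqs],
                                          non_blocking=True)

    def commit_slot_mapping(self, num_tokens: int) -> None:
        # Copy the slot mapping once it has been computed on the CPU.
        self.slot_mapping[:num_tokens].copy_(self.slot_mapping_cpu[:num_tokens],
                                             non_blocking=True)
```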
@@ -654,29 +652,10 @@ def _prepare_inputs(
                            torch.from_numpy(token_indices),
                            out=self.input_ids_cpu[:total_num_scheduled_tokens])

-        # Calculate the slot mapping for each KV cache group.
-        for kv_cache_group_id, kv_cache_group_spec in enumerate(
-                self.kv_cache_config.kv_cache_groups):
-            block_size = kv_cache_group_spec.kv_cache_spec.block_size
-            block_table: BlockTable = self.input_batch.block_table[
-                kv_cache_group_id]
-            # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
-            # -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1]
-            # where K is the max_num_blocks_per_req and the block size is 2.
-            # NOTE(woosuk): We can't simply use `token_indices // block_size`
-            # here because M (max_model_len) is not necessarily divisible by
-            # block_size.
-            block_table_indices = (
-                req_indices * block_table.max_num_blocks_per_req +
-                positions_np // block_size)
-            block_table_cpu = block_table.get_cpu_tensor()
-            block_numbers = block_table_cpu.flatten(
-            )[block_table_indices].numpy()
-            block_offsets = positions_np % block_size
-            np.add(
-                block_numbers * block_size,
-                block_offsets,
-                out=block_table.slot_mapping_np[:total_num_scheduled_tokens])
+        self.input_batch.block_table.compute_slot_mapping(
+            req_indices, positions_np)
+        self.input_batch.block_table.commit_slot_mapping(
+            total_num_scheduled_tokens)

         # Prepare the attention metadata.
         self.query_start_loc_np[0] = 0
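The deleted per-group loop above is presumably what `compute_slot_mapping` now encapsulates behind the block-table abstraction. Here is a minimal NumPy sketch of that math as a standalone function over a plain 2-D block table, rather than vLLM's real `BlockTable`/`MultiGroupBlockTable` classes; the function name and argument shapes are assumptions for illustration.

```python
import numpy as np

def compute_slot_mapping_sketch(block_table_cpu: np.ndarray,
                                req_indices: np.ndarray,
                                positions: np.ndarray,
                                block_size: int) -> np.ndarray:
    """Sketch of the slot-mapping math from the deleted loop above.

    block_table_cpu: [max_num_reqs, max_num_blocks_per_req] physical block ids.
    req_indices / positions: one entry per scheduled token.
    """
    max_num_blocks_per_req = block_table_cpu.shape[1]
    # Index into the flattened block table. Per the NOTE in the deleted code,
    # `token_indices // block_size` cannot be used directly because
    # max_model_len is not necessarily divisible by block_size.
    block_table_indices = (req_indices * max_num_blocks_per_req +
                           positions // block_size)
    block_numbers = block_table_cpu.reshape(-1)[block_table_indices]
    block_offsets = positions % block_size
    # slot = physical_block_id * block_size + offset_within_block
    return block_numbers * block_size + block_offsets

# Example: block_size=2, request 0 uses blocks [0, 1], request 1 uses [5, 6].
block_table_cpu = np.array([[0, 1], [5, 6]])
req_indices = np.array([0, 0, 0, 1, 1])
positions = np.array([0, 1, 2, 0, 1])
print(compute_slot_mapping_sketch(block_table_cpu, req_indices, positions, 2))
# -> [ 0  1  2 10 11]
```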
@@ -722,12 +701,6 @@ def _prepare_inputs(
         for kv_cache_group_id, kv_cache_group_spec in enumerate(
                 self.kv_cache_config.kv_cache_groups):

-            slot_mapping = self.input_batch.block_table[
-                kv_cache_group_id].slot_mapping[:num_reqs]
-            slot_mapping.copy_(self.input_batch.block_table[kv_cache_group_id].
-                               slot_mapping_np[:num_reqs],
-                               non_blocking=True)
-
             common_attn_metadata = CommonAttentionMetadata(
                 query_start_loc=self.query_start_loc[:num_reqs + 1],
                 query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs + 1],
@@ -740,7 +713,8 @@ def _prepare_inputs(
                 max_query_len=max_num_scheduled_tokens,
                 block_table_tensor=self.input_batch.
                 block_table[kv_cache_group_id].get_device_tensor()[:num_reqs],
-                slot_mapping=slot_mapping,
+                slot_mapping=self.input_batch.block_table[kv_cache_group_id].
+                slot_mapping[:total_num_scheduled_tokens],
             )

             if self.speculative_config and \
@@ -1679,8 +1653,7 @@ def propose_draft_token_ids(
                 for i, n in enumerate(num_draft_tokens)
             ]
             num_rejected_tokens_cpu = torch.tensor(num_rejected_tokens,
-                                                   dtype=torch.int32,
-                                                   device=self.device)
+                                                   dtype=torch.int32)
             num_tokens = (num_scheduled_tokens -
                           num_rejected_tokens_cpu.sum())
             common_attn_metadata, token_indices = \
@@ -2389,11 +2362,10 @@ def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None:
                 raise ValueError(
                     f"Unknown KV cache spec type: {type(kv_cache_spec)}")

-            block_table_i = self.input_batch.block_table[i]
             attn_metadata_builder_i = attn_backend_i.get_builder_cls()(
-                weakref.proxy(self),
                 kv_cache_spec,
-                block_table_i,
+                self.vllm_config,
+                self.device,
             )

             if (self.full_cuda_graph
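With this hunk, the metadata builders are no longer constructed with a weakref to the model runner and a per-group `BlockTable`; they receive only the KV cache spec, the vLLM config, and the target device, while the block table and slot mapping now arrive per step through `CommonAttentionMetadata` (see the `_prepare_inputs` hunks above). A hypothetical builder skeleton matching the new call site is sketched below; the class and attribute names are illustrative, not vLLM's actual implementation.

```python
# Hypothetical builder skeleton matching the new call site above; the class
# and attribute names are illustrative, not vLLM's actual implementation.
import torch

class ExampleAttentionMetadataBuilder:
    def __init__(self, kv_cache_spec, vllm_config, device: torch.device):
        # No weakref to the runner and no per-group BlockTable are stored;
        # per-step state (block table tensor, slot mapping) is passed in via
        # CommonAttentionMetadata when metadata is built.
        self.kv_cache_spec = kv_cache_spec
        self.vllm_config = vllm_config
        self.device = device
```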