Commit fa85ae0

move slot table computation into block_table
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
1 parent 80fb05a commit fa85ae0

2 files changed: +47 -42 lines changed

vllm/v1/worker/block_table.py

Lines changed: 37 additions & 4 deletions
@@ -14,12 +14,14 @@ class BlockTable:
 
     def __init__(
         self,
+        block_size: int,
         max_num_reqs: int,
         max_num_blocks_per_req: int,
         max_num_batched_tokens: int,
         pin_memory: bool,
         device: torch.device,
     ):
+        self.block_size = block_size
         self.max_num_reqs = max_num_reqs
         self.max_num_blocks_per_req = max_num_blocks_per_req
         self.max_num_batched_tokens = max_num_batched_tokens
@@ -79,10 +81,31 @@ def swap_row(self, src: int, tgt: int) -> None:
 
         self.block_table_np[[src, tgt]] = self.block_table_np[[tgt, src]]
 
-    def commit(self, num_reqs: int) -> None:
+    def compute_slot_mapping(self, req_indices: np.ndarray,
+                             positions: np.ndarray) -> None:
+        # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
+        # -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1]
+        # where K is the max_num_blocks_per_req and the block size is 2.
+        # NOTE(woosuk): We can't simply use `token_indices // block_size`
+        # here because M (max_model_len) is not necessarily divisible by
+        # block_size.
+        block_table_indices = (req_indices * self.max_num_blocks_per_req +
+                               positions // self.block_size)
+        block_table_cpu = self.get_cpu_tensor()
+        block_numbers = block_table_cpu.flatten()[block_table_indices].numpy()
+        block_offsets = positions % self.block_size
+        np.add(block_numbers * self.block_size,
+               block_offsets,
+               out=self.slot_mapping_np[:req_indices.shape[0]])
+
+    def commit_block_table(self, num_reqs: int) -> None:
         self.block_table[:num_reqs].copy_(self.block_table_cpu[:num_reqs],
                                           non_blocking=True)
 
+    def commit_slot_mapping(self, num_tokens: int) -> None:
+        self.slot_mapping[:num_tokens].copy_(
+            self.slot_mapping_cpu[:num_tokens], non_blocking=True)
+
     def clear(self) -> None:
         self.block_table.fill_(0)
         self.block_table_cpu.fill_(0)
@@ -107,7 +130,8 @@ def __init__(self, max_num_reqs: int, max_model_len: int,
                  max_num_batched_tokens: int, pin_memory: bool,
                  device: torch.device, block_sizes: list[int]) -> None:
         self.block_tables = [
-            BlockTable(max_num_reqs, cdiv(max_model_len, block_size),
+            BlockTable(block_size, max_num_reqs, cdiv(max_model_len,
+                                                      block_size),
                        max_num_batched_tokens, pin_memory, device)
             for block_size in block_sizes
         ]
@@ -129,9 +153,18 @@ def swap_row(self, src: int, tgt: int) -> None:
         for block_table in self.block_tables:
             block_table.swap_row(src, tgt)
 
-    def commit(self, num_reqs: int) -> None:
+    def compute_slot_mapping(self, req_indices: np.ndarray,
+                             positions: np.ndarray) -> None:
+        for block_table in self.block_tables:
+            block_table.compute_slot_mapping(req_indices, positions)
+
+    def commit_block_table(self, num_reqs: int) -> None:
+        for block_table in self.block_tables:
+            block_table.commit_block_table(num_reqs)
+
+    def commit_slot_mapping(self, num_tokens: int) -> None:
         for block_table in self.block_tables:
-            block_table.commit(num_reqs)
+            block_table.commit_slot_mapping(num_tokens)
 
     def clear(self) -> None:
         for block_table in self.block_tables:
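
For reference, here is a minimal standalone sketch of the slot-mapping arithmetic that compute_slot_mapping performs, using toy values (the block-table contents, block_size, max_num_blocks_per_req, and the req_indices/positions arrays below are made up for illustration; the real method writes into the preallocated slot_mapping_np buffer instead of returning a new array):

import numpy as np

block_size = 2
max_num_blocks_per_req = 4  # "K" in the comment above

# Toy block table: request 0 owns physical blocks [10, 11, 12, 13],
# request 1 owns [20, 21, 22, 23].
block_table = np.array([[10, 11, 12, 13],
                        [20, 21, 22, 23]])

# One entry per scheduled token: the request it belongs to and the token's
# position within that request.
req_indices = np.array([0, 0, 0, 1, 1])
positions = np.array([0, 1, 2, 0, 1])

# Index into the flattened block table, then combine the physical block
# number with the in-block offset to get a flat KV-cache slot index.
block_table_indices = (req_indices * max_num_blocks_per_req +
                       positions // block_size)
block_numbers = block_table.flatten()[block_table_indices]
block_offsets = positions % block_size
slot_mapping = block_numbers * block_size + block_offsets
print(slot_mapping)  # [20 21 22 40 41]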

vllm/v1/worker/gpu_model_runner.py

Lines changed: 10 additions & 38 deletions
@@ -4,7 +4,6 @@
 import copy
 import gc
 import time
-import weakref
 from contextlib import contextmanager
 from typing import TYPE_CHECKING, Any, Optional, Union
 
@@ -64,7 +63,6 @@
 from vllm.v1.spec_decode.ngram_proposer import NgramProposer
 from vllm.v1.spec_decode.utils import is_spec_decode_supported
 from vllm.v1.utils import bind_kv_cache
-from vllm.v1.worker.block_table import BlockTable
 from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
 from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
 
@@ -610,7 +608,7 @@ def _prepare_inputs(
 
         # OPTIMIZATION: Start copying the block table first.
         # This way, we can overlap the copy with the following CPU operations.
-        self.input_batch.block_table.commit(num_reqs)
+        self.input_batch.block_table.commit_block_table(num_reqs)
 
         # Get the number of scheduled tokens for each request.
         req_ids = self.input_batch.req_ids
@@ -654,29 +652,10 @@ def _prepare_inputs(
                            torch.from_numpy(token_indices),
                            out=self.input_ids_cpu[:total_num_scheduled_tokens])
 
-        # Calculate the slot mapping for each KV cache group.
-        for kv_cache_group_id, kv_cache_group_spec in enumerate(
-                self.kv_cache_config.kv_cache_groups):
-            block_size = kv_cache_group_spec.kv_cache_spec.block_size
-            block_table: BlockTable = self.input_batch.block_table[
-                kv_cache_group_id]
-            # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
-            # -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1]
-            # where K is the max_num_blocks_per_req and the block size is 2.
-            # NOTE(woosuk): We can't simply use `token_indices // block_size`
-            # here because M (max_model_len) is not necessarily divisible by
-            # block_size.
-            block_table_indices = (
-                req_indices * block_table.max_num_blocks_per_req +
-                positions_np // block_size)
-            block_table_cpu = block_table.get_cpu_tensor()
-            block_numbers = block_table_cpu.flatten(
-            )[block_table_indices].numpy()
-            block_offsets = positions_np % block_size
-            np.add(
-                block_numbers * block_size,
-                block_offsets,
-                out=block_table.slot_mapping_np[:total_num_scheduled_tokens])
+        self.input_batch.block_table.compute_slot_mapping(
+            req_indices, positions_np)
+        self.input_batch.block_table.commit_slot_mapping(
+            total_num_scheduled_tokens)
 
         # Prepare the attention metadata.
         self.query_start_loc_np[0] = 0
@@ -722,12 +701,6 @@ def _prepare_inputs(
         for kv_cache_group_id, kv_cache_group_spec in enumerate(
                 self.kv_cache_config.kv_cache_groups):
 
-            slot_mapping = self.input_batch.block_table[
-                kv_cache_group_id].slot_mapping[:num_reqs]
-            slot_mapping.copy_(self.input_batch.block_table[kv_cache_group_id].
-                               slot_mapping_np[:num_reqs],
-                               non_blocking=True)
-
             common_attn_metadata = CommonAttentionMetadata(
                 query_start_loc=self.query_start_loc[:num_reqs + 1],
                 query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs + 1],
@@ -740,7 +713,8 @@ def _prepare_inputs(
                 max_query_len=max_num_scheduled_tokens,
                 block_table_tensor=self.input_batch.
                 block_table[kv_cache_group_id].get_device_tensor()[:num_reqs],
-                slot_mapping=slot_mapping,
+                slot_mapping=self.input_batch.block_table[kv_cache_group_id].
+                slot_mapping[:total_num_scheduled_tokens],
             )
 
             if self.speculative_config and \
@@ -1679,8 +1653,7 @@ def propose_draft_token_ids(
                 for i, n in enumerate(num_draft_tokens)
             ]
             num_rejected_tokens_cpu = torch.tensor(num_rejected_tokens,
-                                                   dtype=torch.int32,
-                                                   device=self.device)
+                                                   dtype=torch.int32)
             num_tokens = (num_scheduled_tokens -
                           num_rejected_tokens_cpu.sum())
             common_attn_metadata, token_indices =\
@@ -2389,11 +2362,10 @@ def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None:
                 raise ValueError(
                     f"Unknown KV cache spec type: {type(kv_cache_spec)}")
 
-            block_table_i = self.input_batch.block_table[i]
             attn_metadata_builder_i = attn_backend_i.get_builder_cls()(
-                weakref.proxy(self),
                 kv_cache_spec,
-                block_table_i,
+                self.vllm_config,
+                self.device,
             )
 
             if (self.full_cuda_graph
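
The runner-side change keeps the existing optimization of kicking off the block-table H2D copy early (commit_block_table) so it overlaps with the CPU work that follows, and the slot mapping now goes through the same path via commit_slot_mapping. Below is a minimal sketch of that non-blocking pinned-copy pattern on its own, with illustrative buffer names and sizes rather than the real BlockTable attributes:

import torch

max_num_batched_tokens = 8192
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

# Pinned (page-locked) CPU staging buffer plus a preallocated device buffer.
slot_mapping_cpu = torch.zeros(max_num_batched_tokens,
                               dtype=torch.int64,
                               pin_memory=use_cuda)
slot_mapping = torch.zeros(max_num_batched_tokens,
                           dtype=torch.int64,
                           device=device)

# Fill the first num_tokens entries on the CPU ...
num_tokens = 5
slot_mapping_cpu[:num_tokens] = torch.tensor([20, 21, 22, 40, 41])

# ... then copy only that prefix to the device with non_blocking=True, so the
# host is free to continue CPU-side preparation while the transfer runs.
slot_mapping[:num_tokens].copy_(slot_mapping_cpu[:num_tokens],
                                non_blocking=True)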

0 commit comments
