Skip to content

Commit 0f199f1

Browse files
[Core] Avoid KVCacheBlock.__eq__ invocations in FreeKVCacheBlockQueue (#21005)
Signed-off-by: Jialin Ouyang <jialino@meta.com>
1 parent b2eb2b5 commit 0f199f1

File tree

4 files changed

+210
-58
lines changed

4 files changed

+210
-58
lines changed
Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
import gc
4+
import time
5+
from typing import Optional
6+
7+
from tabulate import tabulate
8+
9+
from vllm.utils import FlexibleArgumentParser
10+
from vllm.v1.core.block_pool import BlockPool
11+
12+
13+
class Metric:
14+
def __init__(self) -> None:
15+
self.cnt: int = 0
16+
self.sum_v: int = 0
17+
self.max_v: Optional[int] = None
18+
19+
def update(self, v: int) -> None:
20+
self.cnt += 1
21+
self.sum_v += v
22+
if self.max_v is None:
23+
self.max_v = v
24+
else:
25+
self.max_v = max(self.max_v, v)
26+
27+
def avg_v(self) -> float:
28+
return self.sum_v * 1.0 / self.cnt
29+
30+
31+
def main(args):
32+
rows = []
33+
for allocate_block in args.allocate_blocks:
34+
# Enforce a GC collect ahead to minimize the impact among runs
35+
gc.collect()
36+
block_pool = BlockPool(num_gpu_blocks=args.num_gpu_blocks, enable_caching=True)
37+
38+
get_blocks_metric: Metric = Metric()
39+
free_blocks_metric: Metric = Metric()
40+
for _ in range(args.num_iteration):
41+
t1 = time.monotonic_ns()
42+
blocks = block_pool.get_new_blocks(allocate_block)
43+
t2 = time.monotonic_ns()
44+
block_pool.free_blocks(blocks)
45+
t3 = time.monotonic_ns()
46+
get_blocks_metric.update(t2 - t1)
47+
free_blocks_metric.update(t3 - t2)
48+
49+
if get_blocks_metric.max_v is not None and free_blocks_metric.max_v is not None:
50+
rows.append(
51+
[
52+
get_blocks_metric.cnt,
53+
args.num_gpu_blocks,
54+
allocate_block,
55+
get_blocks_metric.avg_v() / 1000000,
56+
get_blocks_metric.max_v / 1000000.0,
57+
free_blocks_metric.avg_v() / 1000000,
58+
free_blocks_metric.max_v / 1000000.0,
59+
]
60+
)
61+
else:
62+
print(
63+
"No valid metrics found."
64+
f" {get_blocks_metric.max_v=} {free_blocks_metric.max_v=}"
65+
)
66+
67+
print(
68+
tabulate(
69+
rows,
70+
headers=[
71+
"Iterations",
72+
"Total\nBlocks",
73+
"Allocated\nBlocks",
74+
"Get Blocks\nAvg (ms)",
75+
"Get Blocks\nMax (ms)",
76+
"Free Blocks\nAvg (ms)",
77+
"Free Blocks\nMax (ms)",
78+
],
79+
tablefmt="grid",
80+
floatfmt=".6f",
81+
)
82+
)
83+
84+
85+
def invoke_main() -> None:
86+
parser = FlexibleArgumentParser(
87+
description="Benchmark the performance of BlockPool for KV Cache."
88+
)
89+
parser.add_argument("--num-gpu-blocks", type=int, default=100000)
90+
parser.add_argument(
91+
"--num-iteration",
92+
type=int,
93+
default=1000,
94+
help="Number of iterations to run to stablize final data readings",
95+
)
96+
parser.add_argument(
97+
"--allocate-blocks",
98+
type=int,
99+
nargs="*",
100+
default=[10, 50, 100, 500, 1000],
101+
help="Number of blocks to allocate",
102+
)
103+
args = parser.parse_args()
104+
main(args)
105+
106+
107+
if __name__ == "__main__":
108+
invoke_main() # pragma: no cover

tests/v1/core/test_kv_cache_utils.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -132,8 +132,8 @@ def test_free_kv_cache_block_queue_initialization():
132132
block = KVCacheBlock(block_id=0)
133133
queue = FreeKVCacheBlockQueue([block])
134134
assert queue.num_free_blocks == 1
135-
assert queue.free_list_head == block
136-
assert queue.free_list_tail == block
135+
assert queue.fake_free_list_head.next_free_block is block
136+
assert queue.fake_free_list_tail.prev_free_block is block
137137

138138

139139
def test_free_kv_cache_block_queue_operations():
@@ -145,36 +145,38 @@ def test_free_kv_cache_block_queue_operations():
145145

146146
# Check initial state
147147
assert queue.num_free_blocks == 5
148-
assert queue.free_list_head == blocks[0]
149-
assert queue.free_list_tail == blocks[4]
148+
assert queue.fake_free_list_head.next_free_block is blocks[0]
149+
assert queue.fake_free_list_tail.prev_free_block is blocks[4]
150150

151151
# Pop the first block
152152
block1 = queue.popleft()
153153
assert block1 == blocks[0]
154154
assert queue.num_free_blocks == 4
155-
assert queue.free_list_head == blocks[1]
156-
assert queue.free_list_tail == blocks[4]
155+
assert queue.fake_free_list_head.next_free_block is blocks[1]
156+
assert queue.fake_free_list_tail.prev_free_block is blocks[4]
157157

158158
# Remove a block from the middle
159159
block_to_remove = blocks[2]
160160
queue.remove(block_to_remove)
161161
assert queue.num_free_blocks == 3
162-
assert blocks[1].next_free_block == blocks[3]
163-
assert blocks[3].prev_free_block == blocks[1]
162+
assert blocks[1].next_free_block is blocks[3]
163+
assert blocks[3].prev_free_block is blocks[1]
164164

165165
# Append a block back
166166
queue.append(block_to_remove)
167167
assert queue.num_free_blocks == 4
168-
assert queue.free_list_tail == block_to_remove
169-
assert block_to_remove.prev_free_block == blocks[4]
170-
assert block_to_remove.next_free_block is None
168+
assert queue.fake_free_list_tail.prev_free_block is block_to_remove
169+
assert block_to_remove.prev_free_block is blocks[4]
170+
assert block_to_remove.next_free_block is queue.fake_free_list_tail
171171

172172
# Pop blocks until empty
173173
for _ in range(4):
174174
queue.popleft()
175175
assert queue.num_free_blocks == 0
176-
assert queue.free_list_head is None
177-
assert queue.free_list_tail is None
176+
assert (queue.fake_free_list_head.next_free_block
177+
is queue.fake_free_list_tail)
178+
assert (queue.fake_free_list_tail.prev_free_block
179+
is queue.fake_free_list_head)
178180

179181
# Attempt to pop from an empty queue
180182
with pytest.raises(ValueError) as e:

tests/v1/core/test_prefix_caching.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -155,13 +155,14 @@ def test_prefill(hash_algo):
155155
assert block.ref_cnt == 2
156156

157157
# At this point, we should have 5 free blocks left.
158-
assert manager.block_pool.free_block_queue.num_free_blocks == 5
158+
free_block_queue = manager.block_pool.free_block_queue
159+
assert free_block_queue.num_free_blocks == 5
159160

160161
manager.free(req0)
161162
manager.free(req1)
162163

163164
# All blocks should be available.
164-
assert manager.block_pool.free_block_queue.num_free_blocks == 10
165+
assert free_block_queue.num_free_blocks == 10
165166
# The order should be
166167
# [unallocated (6, 7, 8, 9, 10)]
167168
# [unique_req0 (4)]
@@ -188,14 +189,10 @@ def test_prefill(hash_algo):
188189

189190
# Although we only have 6 free blocks, we have 8 blocks in
190191
# the free block queue due to lazy removal.
191-
assert manager.block_pool.free_block_queue.num_free_blocks == 6
192-
assert all([
193-
b.ref_cnt == 0
194-
for b in manager.block_pool.free_block_queue.get_all_free_blocks()
195-
])
196-
assert len([
197-
b for b in manager.block_pool.free_block_queue.get_all_free_blocks()
198-
]) == 6
192+
assert free_block_queue.num_free_blocks == 6
193+
assert all(
194+
[b.ref_cnt == 0 for b in free_block_queue.get_all_free_blocks()])
195+
assert len([b for b in free_block_queue.get_all_free_blocks()]) == 6
199196

200197
manager.free(req2)
201198

@@ -209,9 +206,12 @@ def test_prefill(hash_algo):
209206
computed_blocks)
210207
# This block ID order also checks the eviction order.
211208
assert blocks.get_block_ids() == ([7, 8, 9, 10, 4, 5, 6, 3, 2, 1], )
212-
assert manager.block_pool.free_block_queue.num_free_blocks == 0
213-
assert manager.block_pool.free_block_queue.free_list_head is None
214-
assert manager.block_pool.free_block_queue.free_list_tail is None
209+
210+
assert free_block_queue.num_free_blocks == 0
211+
assert (free_block_queue.fake_free_list_head.next_free_block
212+
is free_block_queue.fake_free_list_tail)
213+
assert (free_block_queue.fake_free_list_tail.prev_free_block
214+
is free_block_queue.fake_free_list_head)
215215

216216

217217
def test_prefill_hybrid_model():

vllm/v1/core/kv_cache_utils.py

Lines changed: 74 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -212,47 +212,81 @@ class FreeKVCacheBlockQueue:
212212
def __init__(self, blocks: list[KVCacheBlock]) -> None:
213213
self.num_free_blocks = len(blocks)
214214

215-
# Initialize the doubly linked list of free blocks.
216-
self.free_list_head: Optional[KVCacheBlock] = blocks[0]
217-
self.free_list_tail: Optional[KVCacheBlock] = blocks[-1]
215+
# Initialize doubly links of consecutive blocks
218216
for i in range(self.num_free_blocks):
219217
if i > 0:
220218
blocks[i].prev_free_block = blocks[i - 1]
221219
if i < self.num_free_blocks - 1:
222220
blocks[i].next_free_block = blocks[i + 1]
223221

222+
# Create a fake head and a tail block for the doubly linked list to
223+
# reduce branching in the code
224+
#
225+
# The implementation garenteed that the fake head and tail
226+
# are NEVER got popped, so we could safely assume each real blocks
227+
# in the queue has prev and next blocks.
228+
self.fake_free_list_head = KVCacheBlock(block_id=-1)
229+
self.fake_free_list_tail = KVCacheBlock(block_id=-1)
230+
if self.num_free_blocks > 0:
231+
# Connect fake_head and fake_tail to the first and last block
232+
# respectively.
233+
self.fake_free_list_head.next_free_block = blocks[0]
234+
blocks[0].prev_free_block = self.fake_free_list_head
235+
self.fake_free_list_tail.prev_free_block = blocks[-1]
236+
blocks[-1].next_free_block = self.fake_free_list_tail
237+
else:
238+
# For empty list, simply connect the fake head and tail.
239+
self.fake_free_list_head.next_free_block = self.fake_free_list_tail
240+
self.fake_free_list_tail.prev_free_block = self.fake_free_list_head
241+
224242
def popleft(self) -> KVCacheBlock:
225243
"""Pop the first free block and reduce num_free_blocks by 1.
226244
227245
Returns:
228246
The first free block.
229247
"""
230-
if not self.free_list_head:
248+
if (self.fake_free_list_head.next_free_block
249+
is self.fake_free_list_tail
250+
or self.fake_free_list_head.next_free_block is None):
251+
assert self.num_free_blocks == 0, (
252+
f"num_free_blocks ({self.num_free_blocks}) is out of sync "
253+
"with the free list.")
231254
raise ValueError("No free blocks available")
232255

233-
block = self.free_list_head
234-
self.remove(block)
235-
return block
256+
first_block: KVCacheBlock = self.fake_free_list_head.next_free_block
257+
258+
if first_block.next_free_block is None:
259+
# This should not happen if the block is from the free list.
260+
# It indicates a bug in the caller's logic.
261+
raise RuntimeError("Invalid block found in popleft() "
262+
"which doesn't have a valid next_free_block")
263+
264+
# Connect fake_head and the next block of first_block (i.e. second block
265+
# or fake tail).
266+
self.fake_free_list_head.next_free_block = first_block.next_free_block
267+
first_block.next_free_block.prev_free_block = self.fake_free_list_head
268+
269+
# Remove the block from the linked list.
270+
first_block.prev_free_block = first_block.next_free_block = None
271+
272+
self.num_free_blocks -= 1
273+
return first_block
236274

237275
def remove(self, block: KVCacheBlock) -> None:
238276
"""Remove a block in the free list and reduce num_free_blocks by 1.
239277
240278
Args:
241279
block: The block to remove.
242280
"""
243-
if block.prev_free_block is not None:
244-
# Link the previous block to the next block.
245-
block.prev_free_block.next_free_block = block.next_free_block
246-
if block.next_free_block is not None:
247-
# Link the next block to the previous block.
248-
block.next_free_block.prev_free_block = block.prev_free_block
249-
250-
if block == self.free_list_head:
251-
# Update the head if the block is the head.
252-
self.free_list_head = block.next_free_block
253-
if block == self.free_list_tail:
254-
# Update the tail if the block is the tail.
255-
self.free_list_tail = block.prev_free_block
281+
if block.prev_free_block is None or block.next_free_block is None:
282+
# This should not happen if the block is from the free list.
283+
# It indicates a bug in the caller's logic.
284+
raise RuntimeError(f"remove() called on an invalid block: {block}")
285+
286+
# Link the previous block to the next block.
287+
block.prev_free_block.next_free_block = block.next_free_block
288+
# Link the next block to the previous block.
289+
block.next_free_block.prev_free_block = block.prev_free_block
256290

257291
# Remove the block from the linked list.
258292
block.prev_free_block = block.next_free_block = None
@@ -265,17 +299,19 @@ def append(self, block: KVCacheBlock) -> None:
265299
Args:
266300
block: The block to append.
267301
"""
268-
if self.free_list_tail is not None:
269-
# Link the last block to the new block.
270-
self.free_list_tail.next_free_block = block
271-
block.prev_free_block = self.free_list_tail
272-
self.free_list_tail = block
273-
else:
274-
# The free list is empty.
275-
assert self.free_list_head is None
276-
self.free_list_head = self.free_list_tail = block
302+
if self.fake_free_list_tail.prev_free_block is None:
303+
raise RuntimeError(
304+
"prev_free_block of fake_free_list_tail should always exist")
305+
last_block: KVCacheBlock = self.fake_free_list_tail.prev_free_block
306+
307+
# Connect the new block after the last block.
308+
last_block.next_free_block = block
309+
block.prev_free_block = last_block
310+
311+
# Connect the fake tail after the new block.
312+
block.next_free_block = self.fake_free_list_tail
313+
self.fake_free_list_tail.prev_free_block = block
277314

278-
block.next_free_block = None
279315
self.num_free_blocks += 1
280316

281317
def get_all_free_blocks(self) -> list[KVCacheBlock]:
@@ -285,8 +321,14 @@ def get_all_free_blocks(self) -> list[KVCacheBlock]:
285321
A list of free blocks.
286322
"""
287323
ret = []
288-
curr_block = self.free_list_head
289-
while curr_block is not None:
324+
if self.fake_free_list_head.next_free_block is None:
325+
raise RuntimeError(
326+
"next_free_block of fake_free_list_head should always exist")
327+
# Start from the first block
328+
curr_block: KVCacheBlock = self.fake_free_list_head.next_free_block
329+
# As long as next_free_block is available, we haven't reached to
330+
# the fake tail yet.
331+
while curr_block.next_free_block is not None:
290332
ret.append(curr_block)
291333
curr_block = curr_block.next_free_block
292334
return ret

0 commit comments

Comments
 (0)