Skip to content

Commit 574454a

Browse files
Avoid KVCacheBlock.__eq__ invocations in FreeKVCacheBlockQueue
Summary: # Optimizations As a common trick for doubly linked list implementations, introducing fake head and tail nodes significantly reduces the implementation overhead and lets us easily get rid of the dataclass.__eq__ comparison. - No dataclass.__eq__ invocation - Shorter code - Branchless All these combined should yield a significant perf improvement for this piece of code. # Observations Per vLLM profiling, kv_cache_manager.allocate_slots incurs a non-negligible cost for each prefill. |{F1980260529}|{F1980260481}|{F1980260497}| Zooming in, we can see that the stack under FreeKVCacheBlockQueue.popleft is non-trivial: popleft -> remove -> string.__eq__, which mainly comes from the dataclass (i.e. KVCacheBlock) equality comparison. Per [dataclasses python library doc](https://docs.python.org/3/library/dataclasses.html#dataclasses.dataclass) ``` dataclasses.dataclass(*, init=True, repr=True, eq=True, order=False, unsafe_hash=False, frozen=False, match_args=True, kw_only=False, slots=False, weakref_slot=False) eq: If true (the default), an __eq__() method will be generated. This method compares the class as if it were a tuple of its fields, in order. Both instances in the comparison must be of the identical type. If the class already defines __eq__(), this parameter is ignored. ``` Test Plan: # Result Typically, block_size is set to 16, so in production usage we would likely allocate 10-1000 blocks. In this range, the optimization gave us up to ~1ms TTFT savings (the improvements are more significant on long inputs). |After|Before| |{F1980286936}|{F1980286941}| Rollback Plan: Reviewed By: CuiCoco Differential Revision: D78292345
1 parent 1e36c86 commit 574454a

File tree

2 files changed

+151
-32
lines changed

2 files changed

+151
-32
lines changed
Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
from vllm.utils import FlexibleArgumentParser
4+
from typing import Optional
5+
import gc
6+
from tabulate import tabulate
7+
from vllm.v1.core.block_pool import BlockPool
8+
import time
9+
10+
class Metric:
    """Running count, sum and maximum of a stream of integer samples.

    Attributes:
        cnt: Number of samples recorded so far.
        sum_v: Sum of all recorded samples.
        max_v: Largest recorded sample, or ``None`` before the first
            ``update`` call.
    """

    def __init__(self) -> None:
        self.cnt: int = 0
        self.sum_v: int = 0
        # Kept as None (not 0) so callers can tell "no samples" from "max is 0".
        self.max_v: Optional[int] = None

    def update(self, v: int) -> None:
        """Record one sample ``v``."""
        self.cnt += 1
        self.sum_v += v
        self.max_v = v if self.max_v is None else max(self.max_v, v)

    def avg_v(self) -> float:
        """Return the mean of the recorded samples.

        Returns:
            ``sum_v / cnt``, or ``0.0`` when no sample has been recorded
            (instead of raising ``ZeroDivisionError``).
        """
        if self.cnt == 0:
            return 0.0
        # True division already yields a float; no ``* 1.0`` needed.
        return self.sum_v / self.cnt
26+
27+
def main(args):
    """Time ``BlockPool`` allocate/free round trips and print a result table.

    For every value in ``args.allocate_blocks`` a new ``BlockPool`` with
    ``args.num_gpu_blocks`` blocks is created, then
    ``get_new_blocks`` followed by ``free_blocks`` is timed
    ``args.num_iteration`` times with ``time.monotonic_ns()``. Each value
    contributes one table row (average and maximum, converted from
    nanoseconds to milliseconds); if no iteration ran, a diagnostic line is
    printed instead of a row.

    Args:
        args: Parsed options providing ``allocate_blocks``,
            ``num_gpu_blocks`` and ``num_iteration``.
    """
    # Divisor converting the nanosecond readings to milliseconds.
    ns_per_ms = 1000000.0
    headers = [
        "Iterations",
        "Total\nBlocks",
        "Allocated\nBlocks",
        "Get Blocks\nAvg (ms)",
        "Get Blocks\nMax (ms)",
        "Free Blocks\nAvg (ms)",
        "Free Blocks\nMax (ms)",
    ]

    rows = []
    for allocate_block in args.allocate_blocks:
        # Collect garbage up front so one run's leftovers do not skew the next.
        gc.collect()
        block_pool = BlockPool(
            num_gpu_blocks=args.num_gpu_blocks, enable_caching=True
        )

        get_blocks_metric = Metric()
        free_blocks_metric = Metric()
        for _ in range(args.num_iteration):
            alloc_start = time.monotonic_ns()
            blocks = block_pool.get_new_blocks(allocate_block)
            alloc_end = time.monotonic_ns()
            block_pool.free_blocks(blocks)
            free_end = time.monotonic_ns()
            get_blocks_metric.update(alloc_end - alloc_start)
            free_blocks_metric.update(free_end - alloc_end)

        # max_v is still None when the timing loop never executed.
        if get_blocks_metric.max_v is None or free_blocks_metric.max_v is None:
            print(
                "No valid metrics found."
                f" {get_blocks_metric.max_v=} {free_blocks_metric.max_v=}"
            )
            continue

        rows.append(
            [
                get_blocks_metric.cnt,
                args.num_gpu_blocks,
                allocate_block,
                get_blocks_metric.avg_v() / ns_per_ms,
                get_blocks_metric.max_v / ns_per_ms,
                free_blocks_metric.avg_v() / ns_per_ms,
                free_blocks_metric.max_v / ns_per_ms,
            ]
        )

    print(tabulate(rows, headers=headers, tablefmt="grid", floatfmt=".6f"))
75+
76+
77+
def invoke_main() -> None:
    """Parse the command-line options and run the benchmark via ``main``.

    Options:
        --num-gpu-blocks: Pool size handed to ``BlockPool`` (default 100000).
        --num-iteration: Timed repetitions per allocation size (default 1000).
        --allocate-blocks: Zero or more allocation sizes to benchmark
            (default ``[10, 50, 100, 500, 1000]``).
    """
    parser = FlexibleArgumentParser(
        description="Benchmark the performance of BlockPool for KV Cache."
    )
    parser.add_argument(
        "--num-gpu-blocks",
        type=int,
        default=100000,
        help="Number of GPU blocks in the BlockPool",
    )
    parser.add_argument(
        "--num-iteration",
        type=int,
        default=1000,
        # Fixed the "stablize" typo shown in --help output.
        help="Number of iterations to run to stabilize final data readings",
    )
    parser.add_argument(
        "--allocate-blocks",
        type=int,
        nargs="*",
        default=[10, 50, 100, 500, 1000],
        help="Number of blocks to allocate",
    )
    args = parser.parse_args()
    main(args)
97+
98+
if __name__ == "__main__":
99+
invoke_main() # pragma: no cover

vllm/v1/core/kv_cache_utils.py

Lines changed: 52 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -212,70 +212,87 @@ class FreeKVCacheBlockQueue:
212212
def __init__(self, blocks: list[KVCacheBlock]) -> None:
213213
self.num_free_blocks = len(blocks)
214214

215-
# Initialize the doubly linked list of free blocks.
216-
self.free_list_head: Optional[KVCacheBlock] = blocks[0]
217-
self.free_list_tail: Optional[KVCacheBlock] = blocks[-1]
215+
# Initialize the doubly linked pointers between consecutive blocks.
218216
for i in range(self.num_free_blocks):
219217
if i > 0:
220218
blocks[i].prev_free_block = blocks[i - 1]
221219
if i < self.num_free_blocks - 1:
222220
blocks[i].next_free_block = blocks[i + 1]
221+
222+
# Create a fake head and a tail block for the doubly linked list to reduce
223+
# branching in the code
224+
#
225+
# The implementation guarantees that the fake head and tail are NEVER popped,
226+
# so we can safely assume each real block in the queue has
227+
# prev and next blocks.
228+
self.fake_free_list_head = KVCacheBlock(block_id=-1)
229+
self.fake_free_list_tail = KVCacheBlock(block_id=-1)
230+
if self.num_free_blocks > 0:
231+
# Connect fake_head and fake_tail to the first and last block respectively.
232+
self.fake_free_list_head.next_free_block = blocks[0]
233+
blocks[0].prev_free_block = self.fake_free_list_head
234+
self.fake_free_list_tail.prev_free_block = blocks[-1]
235+
blocks[-1].next_free_block = self.fake_free_list_tail
236+
else:
237+
# For empty list, simply connect the fake head and tail.
238+
self.fake_free_list_head.next_free_block = self.fake_free_list_tail
239+
self.fake_free_list_tail.prev_free_block = self.fake_free_list_head
223240

224241
def popleft(self) -> KVCacheBlock:
225242
"""Pop the first free block and reduce num_free_blocks by 1.
226243
227244
Returns:
228245
The first free block.
229246
"""
230-
if not self.free_list_head:
247+
if self.num_free_blocks <= 0:
231248
raise ValueError("No free blocks available")
232249

233-
block = self.free_list_head
234-
self.remove(block)
235-
return block
250+
first_block: KVCacheBlock = self.fake_free_list_head.next_free_block
251+
252+
# Connect fake_head and the next block of first_block (i.e. second block
253+
# or fake tail).
254+
self.fake_free_list_head.next_free_block = first_block.next_free_block
255+
first_block.next_free_block.prev_free_block = self.fake_free_list_head
256+
257+
# Remove the block from the linked list.
258+
first_block.prev_free_block = first_block.next_free_block = None
259+
260+
self.num_free_blocks -= 1
261+
return first_block
236262

237263
def remove(self, block: KVCacheBlock) -> None:
238264
"""Remove a block in the free list and reduce num_free_blocks by 1.
239265
240266
Args:
241267
block: The block to remove.
242268
"""
243-
if block.prev_free_block is not None:
244-
# Link the previous block to the next block.
245-
block.prev_free_block.next_free_block = block.next_free_block
246-
if block.next_free_block is not None:
247-
# Link the next block to the previous block.
248-
block.next_free_block.prev_free_block = block.prev_free_block
249-
250-
if block == self.free_list_head:
251-
# Update the head if the block is the head.
252-
self.free_list_head = block.next_free_block
253-
if block == self.free_list_tail:
254-
# Update the tail if the block is the tail.
255-
self.free_list_tail = block.prev_free_block
269+
# Link the previous block to the next block.
270+
block.prev_free_block.next_free_block = block.next_free_block
271+
# Link the next block to the previous block.
272+
block.next_free_block.prev_free_block = block.prev_free_block
256273

257274
# Remove the block from the linked list.
258275
block.prev_free_block = block.next_free_block = None
259276
self.num_free_blocks -= 1
260277

278+
261279
def append(self, block: KVCacheBlock) -> None:
262280
"""Put a block back into the free list and increase
263281
num_free_blocks by 1.
264282
265283
Args:
266284
block: The block to append.
267285
"""
268-
if self.free_list_tail is not None:
269-
# Link the last block to the new block.
270-
self.free_list_tail.next_free_block = block
271-
block.prev_free_block = self.free_list_tail
272-
self.free_list_tail = block
273-
else:
274-
# The free list is empty.
275-
assert self.free_list_head is None
276-
self.free_list_head = self.free_list_tail = block
286+
last_block: KVCacheBlock = self.fake_free_list_tail.prev_free_block
287+
288+
# Connect the new block after the last block.
289+
last_block.next_free_block = block
290+
block.prev_free_block = last_block
291+
292+
# Connect the fake tail after the new block.
293+
block.next_free_block = self.fake_free_list_tail
294+
self.fake_free_list_tail.prev_free_block = block
277295

278-
block.next_free_block = None
279296
self.num_free_blocks += 1
280297

281298
def get_all_free_blocks(self) -> list[KVCacheBlock]:
@@ -285,8 +302,11 @@ def get_all_free_blocks(self) -> list[KVCacheBlock]:
285302
A list of free blocks.
286303
"""
287304
ret = []
288-
curr_block = self.free_list_head
289-
while curr_block is not None:
305+
# Start from the first block
306+
curr_block = self.fake_free_list_head.next_free_block
307+
# As long as next_free_block is available, we haven't reached
308+
# the fake tail yet.
309+
while curr_block.next_free_block is not None:
290310
ret.append(curr_block)
291311
curr_block = curr_block.next_free_block
292312
return ret

0 commit comments

Comments
 (0)