
Commit c6deb39

Merge branch 'vllm-project:main' into large-block-size-solution

2 parents: d326ac2 + 5782581

File tree: 8 files changed, +262 −67 lines

Lines changed: 108 additions & 0 deletions (new file)

@@ -0,0 +1,108 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import gc
import time
from typing import Optional

from tabulate import tabulate

from vllm.utils import FlexibleArgumentParser
from vllm.v1.core.block_pool import BlockPool


class Metric:
    def __init__(self) -> None:
        self.cnt: int = 0
        self.sum_v: int = 0
        self.max_v: Optional[int] = None

    def update(self, v: int) -> None:
        self.cnt += 1
        self.sum_v += v
        if self.max_v is None:
            self.max_v = v
        else:
            self.max_v = max(self.max_v, v)

    def avg_v(self) -> float:
        return self.sum_v * 1.0 / self.cnt


def main(args):
    rows = []
    for allocate_block in args.allocate_blocks:
        # Force a GC pass up front to minimize interference between runs.
        gc.collect()
        block_pool = BlockPool(num_gpu_blocks=args.num_gpu_blocks, enable_caching=True)

        get_blocks_metric: Metric = Metric()
        free_blocks_metric: Metric = Metric()
        for _ in range(args.num_iteration):
            t1 = time.monotonic_ns()
            blocks = block_pool.get_new_blocks(allocate_block)
            t2 = time.monotonic_ns()
            block_pool.free_blocks(blocks)
            t3 = time.monotonic_ns()
            get_blocks_metric.update(t2 - t1)
            free_blocks_metric.update(t3 - t2)

        if get_blocks_metric.max_v is not None and free_blocks_metric.max_v is not None:
            rows.append(
                [
                    get_blocks_metric.cnt,
                    args.num_gpu_blocks,
                    allocate_block,
                    get_blocks_metric.avg_v() / 1000000,
                    get_blocks_metric.max_v / 1000000.0,
                    free_blocks_metric.avg_v() / 1000000,
                    free_blocks_metric.max_v / 1000000.0,
                ]
            )
        else:
            print(
                "No valid metrics found."
                f" {get_blocks_metric.max_v=} {free_blocks_metric.max_v=}"
            )

    print(
        tabulate(
            rows,
            headers=[
                "Iterations",
                "Total\nBlocks",
                "Allocated\nBlocks",
                "Get Blocks\nAvg (ms)",
                "Get Blocks\nMax (ms)",
                "Free Blocks\nAvg (ms)",
                "Free Blocks\nMax (ms)",
            ],
            tablefmt="grid",
            floatfmt=".6f",
        )
    )


def invoke_main() -> None:
    parser = FlexibleArgumentParser(
        description="Benchmark the performance of BlockPool for KV Cache."
    )
    parser.add_argument("--num-gpu-blocks", type=int, default=100000)
    parser.add_argument(
        "--num-iteration",
        type=int,
        default=1000,
        help="Number of iterations to run to stabilize final data readings",
    )
    parser.add_argument(
        "--allocate-blocks",
        type=int,
        nargs="*",
        default=[10, 50, 100, 500, 1000],
        help="Number of blocks to allocate",
    )
    args = parser.parse_args()
    main(args)


if __name__ == "__main__":
    invoke_main()  # pragma: no cover
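
As a quick sanity check of the Metric helper above, a small worked example (not part of the commit; it assumes the Metric class defined in this new file is in scope):

# Feed three sample latencies (in nanoseconds) into a Metric and read the stats back.
m = Metric()
for v in (120, 80, 100):
    m.update(v)

assert m.cnt == 3          # three samples recorded
assert m.max_v == 120      # running maximum
assert m.avg_v() == 100.0  # (120 + 80 + 100) / 3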

csrc/torch_bindings.cpp

Lines changed: 8 additions & 4 deletions

@@ -20,13 +20,17 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   // vLLM custom ops
   //

-  // The default behavior in PyTorch 2.6 is "requires_contiguous", so we need
+  // The default behavior in PyTorch 2.6 was changed to "requires_contiguous",
+  // so we need
   // to override this for many GEMMs with the following tag. Otherwise,
   // torch.compile will force all input tensors to be contiguous(), which
   // will break many custom ops that require column-major weight matrices.
-  // TODO: remove this for PyTorch 2.8, when the default is planned to switch
-  // to match exact eager-mode strides.
-  at::Tag stride_tag = at::Tag::needs_fixed_stride_order;
+  // This was a bug and PyTorch 2.7 has since fixed this.
+#if TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 6
+  #define stride_tag at::Tag::needs_fixed_stride_order
+#else
+  #define stride_tag
+#endif

   ops.def("weak_ref_tensor(Tensor input) -> Tensor");
   ops.impl("weak_ref_tensor", torch::kCUDA, &weak_ref_tensor);
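
The Python-registered ops later in this commit (rocm_aiter_mla.py and fused_moe.py below) apply the same gate at runtime. A minimal sketch of the idea, using a plain torch version check instead of vLLM's is_torch_equal_or_newer helper:

import torch

# Attach needs_fixed_stride_order only on PyTorch 2.6.x, where torch.compile
# would otherwise force inputs to be contiguous; newer releases fixed this.
if torch.__version__.startswith("2.6"):
    stride_tags = (torch.Tag.needs_fixed_stride_order, )
else:
    stride_tags = ()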

tests/v1/core/test_kv_cache_utils.py

Lines changed: 15 additions & 13 deletions

@@ -132,8 +132,8 @@ def test_free_kv_cache_block_queue_initialization():
     block = KVCacheBlock(block_id=0)
     queue = FreeKVCacheBlockQueue([block])
     assert queue.num_free_blocks == 1
-    assert queue.free_list_head == block
-    assert queue.free_list_tail == block
+    assert queue.fake_free_list_head.next_free_block is block
+    assert queue.fake_free_list_tail.prev_free_block is block


 def test_free_kv_cache_block_queue_operations():
@@ -145,36 +145,38 @@ def test_free_kv_cache_block_queue_operations():

     # Check initial state
     assert queue.num_free_blocks == 5
-    assert queue.free_list_head == blocks[0]
-    assert queue.free_list_tail == blocks[4]
+    assert queue.fake_free_list_head.next_free_block is blocks[0]
+    assert queue.fake_free_list_tail.prev_free_block is blocks[4]

     # Pop the first block
     block1 = queue.popleft()
     assert block1 == blocks[0]
     assert queue.num_free_blocks == 4
-    assert queue.free_list_head == blocks[1]
-    assert queue.free_list_tail == blocks[4]
+    assert queue.fake_free_list_head.next_free_block is blocks[1]
+    assert queue.fake_free_list_tail.prev_free_block is blocks[4]

     # Remove a block from the middle
     block_to_remove = blocks[2]
     queue.remove(block_to_remove)
     assert queue.num_free_blocks == 3
-    assert blocks[1].next_free_block == blocks[3]
-    assert blocks[3].prev_free_block == blocks[1]
+    assert blocks[1].next_free_block is blocks[3]
+    assert blocks[3].prev_free_block is blocks[1]

     # Append a block back
     queue.append(block_to_remove)
     assert queue.num_free_blocks == 4
-    assert queue.free_list_tail == block_to_remove
-    assert block_to_remove.prev_free_block == blocks[4]
-    assert block_to_remove.next_free_block is None
+    assert queue.fake_free_list_tail.prev_free_block is block_to_remove
+    assert block_to_remove.prev_free_block is blocks[4]
+    assert block_to_remove.next_free_block is queue.fake_free_list_tail

     # Pop blocks until empty
     for _ in range(4):
         queue.popleft()
     assert queue.num_free_blocks == 0
-    assert queue.free_list_head is None
-    assert queue.free_list_tail is None
+    assert (queue.fake_free_list_head.next_free_block
+            is queue.fake_free_list_tail)
+    assert (queue.fake_free_list_tail.prev_free_block
+            is queue.fake_free_list_head)

     # Attempt to pop from an empty queue
     with pytest.raises(ValueError) as e:
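
These assertion changes track a switch in FreeKVCacheBlockQueue from nullable head/tail pointers to permanent fake (sentinel) head and tail nodes. Below is a minimal sketch of that layout, written only to illustrate the invariants the new asserts check; the attribute names mirror the test, but the rest is illustrative rather than vLLM's implementation:

class _Node:
    def __init__(self):
        self.prev_free_block = None
        self.next_free_block = None


class SentinelFreeQueue:
    def __init__(self, blocks):
        # The sentinels always exist, so an empty queue is simply
        # "head.next is tail" and no None checks are ever needed.
        self.fake_free_list_head = _Node()
        self.fake_free_list_tail = _Node()
        self.fake_free_list_head.next_free_block = self.fake_free_list_tail
        self.fake_free_list_tail.prev_free_block = self.fake_free_list_head
        self.num_free_blocks = 0
        for b in blocks:
            self.append(b)

    def append(self, block):
        # Splice the block in just before the fake tail (O(1)).
        last = self.fake_free_list_tail.prev_free_block
        last.next_free_block = block
        block.prev_free_block = last
        block.next_free_block = self.fake_free_list_tail
        self.fake_free_list_tail.prev_free_block = block
        self.num_free_blocks += 1

    def remove(self, block):
        # Unlink in O(1); thanks to the sentinels every real block
        # always has both neighbours.
        block.prev_free_block.next_free_block = block.next_free_block
        block.next_free_block.prev_free_block = block.prev_free_block
        self.num_free_blocks -= 1

    def popleft(self):
        first = self.fake_free_list_head.next_free_block
        if first is self.fake_free_list_tail:
            raise ValueError("No free blocks available")
        self.remove(first)
        return first


# The empty-queue condition asserted at the end of the updated test:
q = SentinelFreeQueue([])
assert q.fake_free_list_head.next_free_block is q.fake_free_list_tail
assert q.fake_free_list_tail.prev_free_block is q.fake_free_list_head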

tests/v1/core/test_prefix_caching.py

Lines changed: 13 additions & 13 deletions

@@ -155,13 +155,14 @@ def test_prefill(hash_algo):
     assert block.ref_cnt == 2

     # At this point, we should have 5 free blocks left.
-    assert manager.block_pool.free_block_queue.num_free_blocks == 5
+    free_block_queue = manager.block_pool.free_block_queue
+    assert free_block_queue.num_free_blocks == 5

     manager.free(req0)
     manager.free(req1)

     # All blocks should be available.
-    assert manager.block_pool.free_block_queue.num_free_blocks == 10
+    assert free_block_queue.num_free_blocks == 10
     # The order should be
     # [unallocated (6, 7, 8, 9, 10)]
     # [unique_req0 (4)]
@@ -188,14 +189,10 @@

     # Although we only have 6 free blocks, we have 8 blocks in
     # the free block queue due to lazy removal.
-    assert manager.block_pool.free_block_queue.num_free_blocks == 6
-    assert all([
-        b.ref_cnt == 0
-        for b in manager.block_pool.free_block_queue.get_all_free_blocks()
-    ])
-    assert len([
-        b for b in manager.block_pool.free_block_queue.get_all_free_blocks()
-    ]) == 6
+    assert free_block_queue.num_free_blocks == 6
+    assert all(
+        [b.ref_cnt == 0 for b in free_block_queue.get_all_free_blocks()])
+    assert len([b for b in free_block_queue.get_all_free_blocks()]) == 6

     manager.free(req2)

@@ -209,9 +206,12 @@
                                     computed_blocks)
     # This block ID order also checks the eviction order.
     assert blocks.get_block_ids() == ([7, 8, 9, 10, 4, 5, 6, 3, 2, 1], )
-    assert manager.block_pool.free_block_queue.num_free_blocks == 0
-    assert manager.block_pool.free_block_queue.free_list_head is None
-    assert manager.block_pool.free_block_queue.free_list_tail is None
+
+    assert free_block_queue.num_free_blocks == 0
+    assert (free_block_queue.fake_free_list_head.next_free_block
+            is free_block_queue.fake_free_list_tail)
+    assert (free_block_queue.fake_free_list_tail.prev_free_block
+            is free_block_queue.fake_free_list_head)


 def test_prefill_hybrid_model():
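
Note that the rewritten asserts in both test files compare nodes with "is" rather than "==": the free list is a linked structure of specific block objects, so object identity, not value equality, is what the tests need to pin down. A tiny illustration (not from the commit):

x = [1]
y = [1]
assert x == y       # equal values...
assert x is not y   # ...but two distinct objects
alias = x
assert alias is x   # an alias is the same object, so the identity check passes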

vllm/attention/layer.py

Lines changed: 33 additions & 0 deletions

@@ -16,13 +16,42 @@
                                         has_kv_transfer_group,
                                         is_v1_kv_transfer_group)
 from vllm.forward_context import ForwardContext, get_forward_context
+from vllm.logger import init_logger
 from vllm.model_executor.layers.linear import UnquantizedLinearMethod
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
 from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
 from vllm.platforms import _Backend, current_platform
 from vllm.utils import direct_register_custom_op

+logger = init_logger(__name__)
+USE_XFORMERS_OPS = None
+
+
+def check_xformers_availability():
+    global USE_XFORMERS_OPS
+    if USE_XFORMERS_OPS is not None:
+        return USE_XFORMERS_OPS
+
+    if current_platform.is_cuda() and current_platform.has_device_capability(
+            100):
+        # Xformers FA is not compatible with B200
+        USE_XFORMERS_OPS = False
+    else:
+        try:
+            from importlib.util import find_spec
+
+            find_spec("xformers.ops")
+            USE_XFORMERS_OPS = True
+        except ImportError:
+            USE_XFORMERS_OPS = False
+
+    # the warning only needs to be shown once
+    if not USE_XFORMERS_OPS:
+        logger.warning("Xformers is not available, falling back.")
+
+    return USE_XFORMERS_OPS
+

 class Attention(nn.Module):
     """Attention layer.
@@ -314,6 +343,10 @@ def __init__(
             _Backend.TORCH_SDPA, _Backend.XFORMERS, _Backend.PALLAS_VLLM_V1
         } else _Backend.TORCH_SDPA

+        if (self.attn_backend == _Backend.XFORMERS
+                and not check_xformers_availability()):
+            self.attn_backend = _Backend.TORCH_SDPA
+
     def forward(
         self,
         query: torch.Tensor,
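
A side note on the probe used in check_xformers_availability above: importlib.util.find_spec locates a module without importing it, returning None when a top-level name is missing and raising ModuleNotFoundError (a subclass of ImportError) when the parent package of a dotted name is absent, which is why the except ImportError branch suffices. A small self-contained sketch (module names are only examples):

from importlib.util import find_spec


def probe(name: str) -> bool:
    # True when the module can be located, without actually importing it.
    try:
        return find_spec(name) is not None
    except ModuleNotFoundError:
        # Raised for dotted names whose parent package is not installed.
        return False


print(probe("json"))           # True: stdlib module
print(probe("no_such_pkg"))    # False: top-level name not found
print(probe("no_such_pkg.x"))  # False: missing parent package raises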

vllm/attention/ops/rocm_aiter_mla.py

Lines changed: 6 additions & 2 deletions

@@ -6,7 +6,7 @@
 import torch

 from vllm.platforms import current_platform
-from vllm.utils import direct_register_custom_op
+from vllm.utils import direct_register_custom_op, is_torch_equal_or_newer


 def get_aiter_mla_metadata(max_batch_size: int, block_size: int,
@@ -93,8 +93,12 @@ def mla_decode_fwd_fake(


 if current_platform.is_rocm():
+    if is_torch_equal_or_newer("2.7.0"):
+        tags = ()
+    else:
+        tags = (torch.Tag.needs_fixed_stride_order, )
     direct_register_custom_op(op_name="rocm_aiter_mla_decode_fwd",
                               op_func=mla_decode_fwd_impl,
                               mutates_args=["o"],
                               fake_impl=mla_decode_fwd_fake,
-                              tags=[torch.Tag.needs_fixed_stride_order])
+                              tags=tags)

vllm/model_executor/layers/fused_moe/fused_moe.py

Lines changed: 5 additions & 3 deletions

@@ -33,7 +33,7 @@
     dequant_mxfp4)
 from vllm.platforms import current_platform
 from vllm.triton_utils import tl, triton
-from vllm.utils import direct_register_custom_op
+from vllm.utils import direct_register_custom_op, is_torch_equal_or_newer
 from vllm.utils.deep_gemm import is_blackwell_deep_gemm_used

 from .rocm_aiter_fused_moe import is_rocm_aiter_moe_enabled
@@ -1056,7 +1056,8 @@ def inplace_fused_experts_fake(
     op_func=inplace_fused_experts,
     mutates_args=["hidden_states"],
     fake_impl=inplace_fused_experts_fake,
-    tags=(torch.Tag.needs_fixed_stride_order, ),
+    tags=(() if is_torch_equal_or_newer("2.7.0") else
+          (torch.Tag.needs_fixed_stride_order, )),
 )


@@ -1122,7 +1123,8 @@ def outplace_fused_experts_fake(
     op_func=outplace_fused_experts,
     mutates_args=[],
     fake_impl=outplace_fused_experts_fake,
-    tags=(torch.Tag.needs_fixed_stride_order, ),
+    tags=(() if is_torch_equal_or_newer("2.7.0") else
+          (torch.Tag.needs_fixed_stride_order, )),
 )
