
Commit c04904b

first pass backend tests working
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
1 parent be6c009 commit c04904b

File tree

2 files changed: +93 -22 lines changed

tests/v1/attention/test_attention_backends.py

Lines changed: 86 additions & 19 deletions
@@ -13,6 +13,14 @@
 from vllm.v1.attention.backends.utils import CommonAttentionMetadata
 from vllm.v1.kv_cache_interface import FullAttentionSpec
 
+BACKENDS_TO_TEST = ["flash_attn", "flashinfer", "flex_attention"]
+
+# Remove flashinfer from the list if it's not available
+try:
+    import flashinfer  # noqa: F401
+except ImportError:
+    BACKENDS_TO_TEST.remove("flashinfer")
+
 
 def _convert_dtype_to_torch(dtype):
     """Convert ModelDType to torch.dtype."""
@@ -84,6 +92,9 @@ def __init__(self):
         self._q_scale = torch.tensor(1.0)
         self._k_scale = torch.tensor(1.0)
         self._v_scale = torch.tensor(1.0)
+        # Add float versions for flashinfer
+        self._k_scale_float = 1.0
+        self._v_scale_float = 1.0
 
 
 def run_attention_backend(backend_name: str, kv_cache_spec: FullAttentionSpec,
@@ -96,22 +107,52 @@ def run_attention_backend(backend_name: str, kv_cache_spec: FullAttentionSpec,
 
     builder_cls, impl_cls = get_attention_backend(backend_name)
 
-    # Build metadata
-    builder = builder_cls(kv_cache_spec, vllm_config, device)
-    attn_metadata = builder.build(
-        common_prefix_len=0,
-        common_attn_metadata=common_attn_metadata,
-    )
+    # Mock flashinfer's get_per_layer_parameters if needed
+    if backend_name == "flashinfer":
+        import unittest.mock
+
+        from vllm.v1.attention.backends.flashinfer import PerLayerParameters
+
+        def mock_get_per_layer_parameters(vllm_config):
+            # Return mock parameters for a single layer
+            head_size = vllm_config.model_config.get_head_size()
+            return {
+                "mock_layer":
+                PerLayerParameters(
+                    window_left=-1,  # No sliding window
+                    logits_soft_cap=0.0,  # No soft cap
+                    sm_scale=1.0 / (head_size**0.5)  # Standard scale
+                )
+            }
+
+        with unittest.mock.patch(
+                'vllm.v1.attention.backends.flashinfer.get_per_layer_parameters',
+                mock_get_per_layer_parameters):
+            builder = builder_cls(kv_cache_spec, vllm_config, device)
+            attn_metadata = builder.build(
+                common_prefix_len=0,
+                common_attn_metadata=common_attn_metadata,
+            )
+    else:
+        # Build metadata
+        builder = builder_cls(kv_cache_spec, vllm_config, device)
+        attn_metadata = builder.build(
+            common_prefix_len=0,
+            common_attn_metadata=common_attn_metadata,
+        )
 
     # Instantiate implementation
-    num_heads = kv_cache_spec.num_kv_heads
-    head_size = kv_cache_spec.head_size
+    num_heads = vllm_config.model_config.get_num_attention_heads(
+        vllm_config.parallel_config)
+    num_kv_heads = vllm_config.model_config.get_num_kv_heads(
+        vllm_config.parallel_config)
+    head_size = vllm_config.model_config.get_head_size()
     scale = 1.0 / (head_size**0.5)
     impl = impl_cls(
         num_heads=num_heads,
         head_size=head_size,
         scale=scale,
-        num_kv_heads=num_heads,
+        num_kv_heads=num_kv_heads,
         alibi_slopes=None,
         sliding_window=None,
         kv_cache_dtype="auto",
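
A standalone sketch of the unittest.mock.patch pattern used above, with toy stand-ins instead of vLLM's flashinfer module; every name below is illustrative.

    import unittest.mock


    def get_params():
        # Stand-in for a module-level lookup that needs live model layers.
        raise RuntimeError("requires a real model")


    def build_metadata():
        # Stand-in for builder.build(); resolves get_params through the module.
        return {"params": get_params()}


    def fake_get_params():
        # Minimal per-layer parameters, mirroring the mock in the diff above.
        return {"mock_layer": {"window_left": -1, "logits_soft_cap": 0.0}}


    # Patch the module-level name for the duration of the build, then restore.
    with unittest.mock.patch(f"{__name__}.get_params", fake_get_params):
        metadata = build_metadata()

    assert metadata["params"]["mock_layer"]["window_left"] == -1
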
@@ -255,6 +296,8 @@ def test_backend_correctness(batch_spec_name: str, model: str):
         batch_spec, vllm_config.cache_config.block_size, device)
 
     # 3. Simulate Paged KV Cache and a realistic slot_mapping
+    # Note: In vLLM, block_id=0 is reserved as the null block and should not
+    # be used
     block_table = common_attn_metadata.block_table_tensor
     num_blocks = vllm_config.cache_config.num_gpu_blocks or 1000
     kv_cache = torch.empty(2,
@@ -267,7 +310,9 @@
     kv_cache_flat = kv_cache.view(2, -1, num_kv_heads, head_size)
 
     # Populate the cache with the context tokens
-    start_block_idx = 0
+    # Start from block_id=1 since block_id=0 is considered the null block in
+    # vLLM
+    start_block_idx = 1
     for i in range(batch_size):
         k_context, v_context = all_k_context[i], all_v_context[i]
         start = start_block_idx * block_size
@@ -279,13 +324,18 @@
         start_block_idx += cdiv(seq_lens[i], block_size)
 
     blocks_end = start_block_idx
-    # randomly permute the context blocks
-    perm = torch.arange(blocks_end)  #torch.randperm(blocks_end)
-    inv_perm = torch.argsort(perm)
-    kv_cache = kv_cache[:, perm, ...]
+    # randomly permute the context blocks (excluding block 0 which is null)
+    perm = torch.randperm(blocks_end -
+                          1) + 1  # Random permutation starting from block 1
+    inv_perm = torch.zeros(blocks_end, dtype=torch.long, device=device)
+    inv_perm[1:] = torch.argsort(
+        perm) + 1  # Add 1 to account for starting from block 1
+    kv_cache[:, 1:blocks_end, ...] = kv_cache[:, perm, ...]
 
     # Construct the right block table
-    start_block_idx = 0
+    # Start from block_id=1 since block_id=0 is considered the null block in
+    # vLLM
+    start_block_idx = 1
     for i in range(batch_size):
         num_blocks = cdiv(seq_lens[i], block_size)
         start = start_block_idx
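
A self-contained check of the permute-and-invert logic above, on a toy one-dimensional cache; the shapes are simplified but the index arithmetic is the same, and the sizes are made up.

    import torch

    num_blocks = 6
    cache = torch.arange(num_blocks)           # stand-in for kv_cache blocks

    perm = torch.randperm(num_blocks - 1) + 1  # shuffle blocks 1..N-1 only
    inv_perm = torch.zeros(num_blocks, dtype=torch.long)
    inv_perm[1:] = torch.argsort(perm) + 1

    shuffled = cache.clone()
    shuffled[1:] = cache[perm]

    # Looking up a block through inv_perm recovers its original contents,
    # and block 0 (the null block) is never moved.
    assert torch.equal(shuffled[inv_perm[1:]], cache[1:])
    assert shuffled[0] == cache[0]
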
@@ -306,14 +356,23 @@
 
     # 4. Run vLLM backends and compare
     # Note: flex_attention has known Triton kernel compatibility issues
-    # with test infrastructure
-    backends_to_test = ["flash_attn"]  # flex_attention has compilation issues
-    for backend_name in backends_to_test:
+    # with test infrastructures
+    for backend_name in BACKENDS_TO_TEST:
+        # FlashAttention + FlexAttention:
+        #   [2, num_blocks, block_size, num_kv_heads, head_size]
+        # FlashInfer:
+        #   [num_blocks, 2, block_size, num_kv_heads, head_size]
+        # Select the appropriate KV cache format for each backend
+        kv_cache_for_backend = kv_cache
+        if backend_name == "flashinfer":
+            kv_cache_for_backend = kv_cache.transpose(0, 1)
+
         backend_output = run_attention_backend(backend_name, kv_cache_spec,
                                                vllm_config, device,
                                                common_attn_metadata,
                                                query_vllm, key_vllm,
-                                               value_vllm, kv_cache)
+                                               value_vllm,
+                                               kv_cache_for_backend)
 
         # Check shape and dtype consistency
         assert backend_output.shape == sdpa_output.shape, (
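
A small shape-only sketch of the layout difference handled above; the dimension sizes are arbitrary.

    import torch

    # flash_attn / flex_attention layout: [2, num_blocks, block_size, H, D]
    # FlashInfer layout:                  [num_blocks, 2, block_size, H, D]
    num_blocks, block_size, num_kv_heads, head_size = 4, 16, 2, 8
    kv_cache = torch.zeros(2, num_blocks, block_size, num_kv_heads, head_size)

    kv_cache_flashinfer = kv_cache.transpose(0, 1)
    assert kv_cache_flashinfer.shape == (num_blocks, 2, block_size,
                                         num_kv_heads, head_size)
    # transpose() returns a view, so both layouts share the same storage.
    assert kv_cache_flashinfer.data_ptr() == kv_cache.data_ptr()
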
@@ -330,6 +389,14 @@
         rtol = 1e-5 if backend_output.dtype == torch.float32 else 1e-2
         atol = 1e-4 if backend_output.dtype == torch.float32 else 1e-3
 
+        # Flashinfer may have slightly different numerical behavior
+        if backend_name == "flashinfer":
+            atol = 1e-3 if backend_output.dtype == torch.float32 else 5e-3
+
+        # Flex_attention may have slightly different numerical behavior
+        if backend_name == "flex_attention":
+            atol = 1e-2 if backend_output.dtype == torch.float32 else 1e-2
+
         max_diff = torch.max(torch.abs(backend_output - sdpa_output)).item()
         assert torch.allclose(
             backend_output, sdpa_output, rtol=rtol, atol=atol), (
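
For reference, how the rtol/atol budgets above behave: torch.allclose passes when |a - b| <= atol + rtol * |b| elementwise. The values below are made-up toy numbers.

    import torch

    a = torch.tensor([1.000, 2.000], dtype=torch.float16)
    b = torch.tensor([1.004, 2.004], dtype=torch.float16)

    # An fp16-style budget (rtol=1e-2, atol=1e-3) absorbs the ~4e-3 gap...
    assert torch.allclose(a, b, rtol=1e-2, atol=1e-3)
    # ...while an fp32-style budget (rtol=1e-5, atol=1e-4) does not.
    assert not torch.allclose(a, b, rtol=1e-5, atol=1e-4)
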

tests/v1/attention/utils.py

Lines changed: 7 additions & 3 deletions
@@ -54,8 +54,12 @@ def create_common_attn_metadata(
                             device=device)
     seq_lens_cpu = seq_lens.cpu()
 
-    # Create computed tokens (assume all tokens are computed for simplicity)
-    num_computed_tokens_cpu = seq_lens_cpu.clone()
+    # Create computed tokens (context length for each sequence)
+    context_lens = [
+        batch_spec.seq_lens[i] - batch_spec.query_lens[i]
+        for i in range(batch_spec.batch_size)
+    ]
+    num_computed_tokens_cpu = torch.tensor(context_lens, dtype=torch.int32)
 
     # Create block table (random for testing)
     max_blocks = max(batch_spec.seq_lens) // block_size + 1
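
The computed-token bookkeeping above on toy numbers (the figures are made up): the context length is the part of each sequence that is already in the KV cache, so a query length equal to the sequence length means a pure prefill.

    import torch

    seq_lens = [32, 40, 48]    # total tokens per sequence
    query_lens = [1, 8, 48]    # tokens being processed this step

    context_lens = [s - q for s, q in zip(seq_lens, query_lens)]
    num_computed_tokens_cpu = torch.tensor(context_lens, dtype=torch.int32)

    assert num_computed_tokens_cpu.tolist() == [31, 32, 0]  # 0 => pure prefill
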
@@ -126,7 +130,7 @@ def create_standard_kv_cache_spec(
     """Create a FullAttentionSpec from ModelParams only."""
     return FullAttentionSpec(
         block_size=vllm_config.cache_config.block_size,
-        num_kv_heads=vllm_config.model_config.get_num_attention_heads(
+        num_kv_heads=vllm_config.model_config.get_num_kv_heads(
             vllm_config.parallel_config),
         head_size=vllm_config.model_config.get_head_size(),
         dtype=vllm_config.model_config.dtype,
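
Why get_num_kv_heads matters for the cache spec: with grouped-query attention the KV cache stores only the KV heads, which is usually fewer than the number of query heads. The head counts below are assumptions for illustration, not from any real config.

    num_attention_heads = 32   # query heads
    num_kv_heads = 8           # heads actually stored in the KV cache

    # Sizing the spec with num_attention_heads would make the cache 4x larger
    # than what the attention implementation writes, so shapes would not match.
    assert num_attention_heads % num_kv_heads == 0
    queries_per_kv_head = num_attention_heads // num_kv_heads
    assert queries_per_kv_head == 4
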
