from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.kv_cache_interface import FullAttentionSpec

+ BACKENDS_TO_TEST = ["flash_attn", "flashinfer", "flex_attention"]
+
+ # Remove flashinfer from the list if it's not available
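+ # (the import probe only checks that the package is installed, not that the
+ # current platform can actually run it)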
+ try:
+     import flashinfer  # noqa: F401
+ except ImportError:
+     BACKENDS_TO_TEST.remove("flashinfer")
+

def _convert_dtype_to_torch(dtype):
    """Convert ModelDType to torch.dtype."""
@@ -84,6 +92,9 @@ def __init__(self):
        self._q_scale = torch.tensor(1.0)
        self._k_scale = torch.tensor(1.0)
        self._v_scale = torch.tensor(1.0)
+         # Add float versions for flashinfer
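+         # (the FlashInfer path consumes the scales as plain Python floats, so
+         # mirror the tensor values here)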
+         self._k_scale_float = 1.0
+         self._v_scale_float = 1.0


def run_attention_backend(backend_name: str, kv_cache_spec: FullAttentionSpec,
@@ -96,22 +107,52 @@ def run_attention_backend(backend_name: str, kv_cache_spec: FullAttentionSpec,

    builder_cls, impl_cls = get_attention_backend(backend_name)

-     # Build metadata
-     builder = builder_cls(kv_cache_spec, vllm_config, device)
-     attn_metadata = builder.build(
-         common_prefix_len=0,
-         common_attn_metadata=common_attn_metadata,
-     )
+     # Mock flashinfer's get_per_layer_parameters if needed
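+     # (the real helper derives these values from the loaded model's attention
+     # layers; this test never builds a model, so a single mocked layer is
+     # patched in instead)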
+     if backend_name == "flashinfer":
+         import unittest.mock
+
+         from vllm.v1.attention.backends.flashinfer import PerLayerParameters
+
+         def mock_get_per_layer_parameters(vllm_config):
+             # Return mock parameters for a single layer
+             head_size = vllm_config.model_config.get_head_size()
+             return {
+                 "mock_layer":
+                 PerLayerParameters(
+                     window_left=-1,  # No sliding window
+                     logits_soft_cap=0.0,  # No soft cap
+                     sm_scale=1.0 / (head_size**0.5)  # Standard scale
+                 )
+             }
+
+         with unittest.mock.patch(
+                 'vllm.v1.attention.backends.flashinfer.get_per_layer_parameters',
+                 mock_get_per_layer_parameters):
+             builder = builder_cls(kv_cache_spec, vllm_config, device)
+             attn_metadata = builder.build(
+                 common_prefix_len=0,
+                 common_attn_metadata=common_attn_metadata,
+             )
+     else:
+         # Build metadata
+         builder = builder_cls(kv_cache_spec, vllm_config, device)
+         attn_metadata = builder.build(
+             common_prefix_len=0,
+             common_attn_metadata=common_attn_metadata,
+         )

    # Instantiate implementation
-     num_heads = kv_cache_spec.num_kv_heads
-     head_size = kv_cache_spec.head_size
+     num_heads = vllm_config.model_config.get_num_attention_heads(
+         vllm_config.parallel_config)
+     num_kv_heads = vllm_config.model_config.get_num_kv_heads(
+         vllm_config.parallel_config)
+     head_size = vllm_config.model_config.get_head_size()
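+     # num_heads counts query heads; num_kv_heads can be smaller for GQA/MQA
+     # models, which is why both are read from the model config.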
    scale = 1.0 / (head_size**0.5)
    impl = impl_cls(
        num_heads=num_heads,
        head_size=head_size,
        scale=scale,
-         num_kv_heads=num_heads,
+         num_kv_heads=num_kv_heads,
        alibi_slopes=None,
        sliding_window=None,
        kv_cache_dtype="auto",
@@ -255,6 +296,8 @@ def test_backend_correctness(batch_spec_name: str, model: str):
        batch_spec, vllm_config.cache_config.block_size, device)

    # 3. Simulate Paged KV Cache and a realistic slot_mapping
+     # Note: In vLLM, block_id=0 is reserved as the null block and should not
+     # be used
    block_table = common_attn_metadata.block_table_tensor
    num_blocks = vllm_config.cache_config.num_gpu_blocks or 1000
    kv_cache = torch.empty(2,
@@ -267,7 +310,9 @@ def test_backend_correctness(batch_spec_name: str, model: str):
    kv_cache_flat = kv_cache.view(2, -1, num_kv_heads, head_size)

    # Populate the cache with the context tokens
-     start_block_idx = 0
+     # Start from block_id=1 since block_id=0 is considered the null block in
+     # vLLM
+     start_block_idx = 1
    for i in range(batch_size):
        k_context, v_context = all_k_context[i], all_v_context[i]
        start = start_block_idx * block_size
@@ -279,13 +324,18 @@ def test_backend_correctness(batch_spec_name: str, model: str):
        start_block_idx += cdiv(seq_lens[i], block_size)

    blocks_end = start_block_idx
-     # randomly permute the context blocks
-     perm = torch.arange(blocks_end)  # torch.randperm(blocks_end)
-     inv_perm = torch.argsort(perm)
-     kv_cache = kv_cache[:, perm, ...]
+     # randomly permute the context blocks (excluding block 0, which is null)
+     perm = torch.randperm(blocks_end -
+                           1) + 1  # Random permutation starting from block 1
+     inv_perm = torch.zeros(blocks_end, dtype=torch.long, device=device)
+     inv_perm[1:] = torch.argsort(
+         perm) + 1  # Add 1 to account for starting from block 1
+     kv_cache[:, 1:blocks_end, ...] = kv_cache[:, perm, ...]
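+     # After the shuffle, physical slot j + 1 holds original block perm[j];
+     # inv_perm[b] gives the physical slot where original block b now lives.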

    # Construct the right block table
-     start_block_idx = 0
+     # Start from block_id=1 since block_id=0 is considered the null block in
+     # vLLM
+     start_block_idx = 1
    for i in range(batch_size):
        num_blocks = cdiv(seq_lens[i], block_size)
        start = start_block_idx
@@ -306,14 +356,23 @@ def test_backend_correctness(batch_spec_name: str, model: str):

    # 4. Run vLLM backends and compare
    # Note: flex_attention has known Triton kernel compatibility issues
-     # with test infrastructure
-     backends_to_test = ["flash_attn"]  # flex_attention has compilation issues
-     for backend_name in backends_to_test:
+     # with test infrastructure
+     for backend_name in BACKENDS_TO_TEST:
+         # FlashAttention + FlexAttention:
+         # [2, num_blocks, block_size, num_kv_heads, head_size]
+         # FlashInfer:
+         # [num_blocks, 2, block_size, num_kv_heads, head_size]
+         # Select the appropriate KV cache format for each backend
+         kv_cache_for_backend = kv_cache
+         if backend_name == "flashinfer":
+             kv_cache_for_backend = kv_cache.transpose(0, 1)
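+             # transpose(0, 1) returns a view in the [num_blocks, 2, ...]
+             # layout, so no data is copied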
+
        backend_output = run_attention_backend(backend_name, kv_cache_spec,
                                               vllm_config, device,
                                               common_attn_metadata,
                                               query_vllm, key_vllm,
-                                                value_vllm, kv_cache)
+                                                value_vllm,
+                                                kv_cache_for_backend)

        # Check shape and dtype consistency
        assert backend_output.shape == sdpa_output.shape, (
@@ -330,6 +389,14 @@ def test_backend_correctness(batch_spec_name: str, model: str):
        rtol = 1e-5 if backend_output.dtype == torch.float32 else 1e-2
        atol = 1e-4 if backend_output.dtype == torch.float32 else 1e-3

+         # Flashinfer may have slightly different numerical behavior
+         if backend_name == "flashinfer":
+             atol = 1e-3 if backend_output.dtype == torch.float32 else 5e-3
+
+         # flex_attention may have slightly different numerical behavior
+         if backend_name == "flex_attention":
+             atol = 1e-2
+
        max_diff = torch.max(torch.abs(backend_output - sdpa_output)).item()
        assert torch.allclose(
            backend_output, sdpa_output, rtol=rtol, atol=atol), (