 import pytest
 import torch
 
-from tests.v1.attention.utils import (BatchSpec, create_common_attn_metadata,
+from tests.v1.attention.utils import (BatchSpec, _Backend,
+                                      create_common_attn_metadata,
                                       create_standard_kv_cache_spec,
                                       create_vllm_config,
                                       get_attention_backend)
 from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv
 from vllm.v1.attention.backends.utils import CommonAttentionMetadata
 from vllm.v1.kv_cache_interface import FullAttentionSpec
 
-BACKENDS_TO_TEST = ["flash_attn", "flashinfer", "flex_attention"]
+BACKENDS_TO_TEST = [
+    _Backend.FLASH_ATTN_VLLM_V1, _Backend.FLASHINFER_VLLM_V1,
+    _Backend.FLEX_ATTENTION, _Backend.TRITON_ATTN_VLLM_V1
+]
 
 # Remove flashinfer from the list if it's not available
 try:
     import flashinfer  # noqa: F401
 except ImportError:
-    BACKENDS_TO_TEST.remove("flashinfer")
+    BACKENDS_TO_TEST.remove(_Backend.FLASHINFER_VLLM_V1)
 
 
 def _convert_dtype_to_torch(dtype):
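Note: _Backend here is imported from tests.v1.attention.utils (per the hunk above), so BACKENDS_TO_TEST now holds enum members rather than strings and the later hunks compare backends by enum identity. A minimal, illustrative sketch of consuming such an enum-valued list via pytest parametrization; the standalone test below and its name are assumptions, not part of this change (the real test_backend_correctness appears to iterate over BACKENDS_TO_TEST internally):

import pytest

from tests.v1.attention.utils import _Backend

# Mirrors the list added in the hunk above so the sketch is self-contained.
BACKENDS_TO_TEST = [
    _Backend.FLASH_ATTN_VLLM_V1, _Backend.FLASHINFER_VLLM_V1,
    _Backend.FLEX_ATTENTION, _Backend.TRITON_ATTN_VLLM_V1
]


@pytest.mark.parametrize("backend", BACKENDS_TO_TEST, ids=lambda b: b.name)
def test_backend_member(backend: _Backend):
    # Enum members compare by identity, the same pattern the hunks below use
    # to special-case FlashInfer's KV-cache layout.
    assert isinstance(backend, _Backend)
    if backend is _Backend.FLASHINFER_VLLM_V1:
        assert backend.name == "FLASHINFER_VLLM_V1"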
@@ -197,18 +201,18 @@ def __init__(self):
         self._v_scale_float = 1.0
 
 
-def run_attention_backend(backend_name: str, kv_cache_spec: FullAttentionSpec,
+def run_attention_backend(backend: _Backend, kv_cache_spec: FullAttentionSpec,
                           vllm_config, device: torch.device,
                           common_attn_metadata: CommonAttentionMetadata,
                           query: torch.Tensor, key: torch.Tensor,
                           value: torch.Tensor,
                           kv_cache: torch.Tensor) -> torch.Tensor:
     """Run attention computation using the specified backend's AttentionImpl."""
 
-    builder_cls, impl_cls = get_attention_backend(backend_name)
+    builder_cls, impl_cls = get_attention_backend(backend)
 
     # Mock flashinfer's get_per_layer_parameters if needed
-    if backend_name == "flashinfer":
+    if backend == _Backend.FLASHINFER_VLLM_V1:
         import unittest.mock
 
         from vllm.v1.attention.backends.flashinfer import PerLayerParameters
@@ -417,7 +421,7 @@ def test_backend_correctness(batch_spec_name: str, model: str):
         # [num_blocks, 2, block_size, num_kv_heads, head_size]
         # Select the appropriate KV cache format for each backend
         kv_cache_for_backend = kv_cache
-        if backend_name == "flashinfer":
+        if backend_name == _Backend.FLASHINFER_VLLM_V1:
             kv_cache_for_backend = kv_cache.transpose(0, 1)
 
         backend_output = run_attention_backend(backend_name, kv_cache_spec,
@@ -440,17 +444,12 @@ def test_backend_correctness(batch_spec_name: str, model: str):
 
         # Check numerical similarity
         rtol = 1e-2
-        atol = 1e-3
+        atol = 5e-3
 
-        # Flashinfer and Flex_attention may have slightly different
-        # numerical behavior
-        if backend_name == "flashinfer":
-            atol = 5e-3
-
-        if backend_name == "flex_attention":
-            atol = 5e-1  # TODO: figuure out why flex_attention has such large
-            # numerical differences for
-            # medium_decode, medium_prefill, mixed_medium
+        if backend_name == _Backend.FLEX_ATTENTION:
+            atol = 5e-1  # TODO: figure out why flex_attention has such large
+            # numerical differences for medium_decode, medium_prefill,
+            # mixed_medium
 
         max_diff = torch.max(torch.abs(backend_output - sdpa_output)).item()
         max_rel_diff = torch.max(
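For reference, a hedged sketch of the kind of tolerance check this hunk tunes; the helper name and the clamp on the denominator are assumptions, and the real test's exact comparison and failure message are truncated above:

import torch


def assert_outputs_close(backend_output: torch.Tensor,
                         sdpa_output: torch.Tensor,
                         rtol: float = 1e-2,
                         atol: float = 5e-3) -> None:
    # Largest absolute and relative deviations, useful in the failure message.
    abs_diff = torch.abs(backend_output - sdpa_output)
    max_diff = torch.max(abs_diff).item()
    max_rel_diff = torch.max(
        abs_diff / torch.abs(sdpa_output).clamp(min=1e-6)).item()
    assert torch.allclose(backend_output, sdpa_output, rtol=rtol, atol=atol), (
        f"max abs diff {max_diff:.3e}, max rel diff {max_rel_diff:.3e}")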