
Commit 834b496

review comments
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
1 parent da92d38 commit 834b496

File tree: 3 files changed, +33 -28 lines changed

    tests/v1/attention/test_attention_backends.py
    tests/v1/attention/utils.py
    vllm/v1/attention/backends/flash_attn.py

tests/v1/attention/test_attention_backends.py

Lines changed: 16 additions & 17 deletions
@@ -5,21 +5,25 @@
 import pytest
 import torch
 
-from tests.v1.attention.utils import (BatchSpec, create_common_attn_metadata,
+from tests.v1.attention.utils import (BatchSpec, _Backend,
+                                      create_common_attn_metadata,
                                       create_standard_kv_cache_spec,
                                       create_vllm_config,
                                       get_attention_backend)
 from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv
 from vllm.v1.attention.backends.utils import CommonAttentionMetadata
 from vllm.v1.kv_cache_interface import FullAttentionSpec
 
-BACKENDS_TO_TEST = ["flash_attn", "flashinfer", "flex_attention"]
+BACKENDS_TO_TEST = [
+    _Backend.FLASH_ATTN_VLLM_V1, _Backend.FLASHINFER_VLLM_V1,
+    _Backend.FLEX_ATTENTION, _Backend.TRITON_ATTN_VLLM_V1
+]
 
 # Remove flashinfer from the list if it's not available
 try:
     import flashinfer  # noqa: F401
 except ImportError:
-    BACKENDS_TO_TEST.remove("flashinfer")
+    BACKENDS_TO_TEST.remove(_Backend.FLASHINFER_VLLM_V1)
 
 
 def _convert_dtype_to_torch(dtype):
@@ -197,18 +201,18 @@ def __init__(self):
         self._v_scale_float = 1.0
 
 
-def run_attention_backend(backend_name: str, kv_cache_spec: FullAttentionSpec,
+def run_attention_backend(backend: _Backend, kv_cache_spec: FullAttentionSpec,
                           vllm_config, device: torch.device,
                           common_attn_metadata: CommonAttentionMetadata,
                           query: torch.Tensor, key: torch.Tensor,
                           value: torch.Tensor,
                           kv_cache: torch.Tensor) -> torch.Tensor:
     """Run attention computation using the specified backend's AttentionImpl."""
 
-    builder_cls, impl_cls = get_attention_backend(backend_name)
+    builder_cls, impl_cls = get_attention_backend(backend)
 
     # Mock flashinfer's get_per_layer_parameters if needed
-    if backend_name == "flashinfer":
+    if backend == _Backend.FLASHINFER_VLLM_V1:
         import unittest.mock
 
         from vllm.v1.attention.backends.flashinfer import PerLayerParameters
@@ -417,7 +421,7 @@ def test_backend_correctness(batch_spec_name: str, model: str):
         # [num_blocks, 2, block_size, num_kv_heads, head_size]
         # Select the appropriate KV cache format for each backend
         kv_cache_for_backend = kv_cache
-        if backend_name == "flashinfer":
+        if backend_name == _Backend.FLASHINFER_VLLM_V1:
             kv_cache_for_backend = kv_cache.transpose(0, 1)
 
         backend_output = run_attention_backend(backend_name, kv_cache_spec,
@@ -440,17 +444,12 @@ def test_backend_correctness(batch_spec_name: str, model: str):
 
         # Check numerical similarity
         rtol = 1e-2
-        atol = 1e-3
+        atol = 5e-3
 
-        # Flashinfer and Flex_attention may have slightly different
-        # numerical behavior
-        if backend_name == "flashinfer":
-            atol = 5e-3
-
-        if backend_name == "flex_attention":
-            atol = 5e-1  # TODO: figuure out why flex_attention has such large
-            # numerical differences for
-            # medium_decode, medium_prefill, mixed_medium
+        if backend_name == _Backend.FLEX_ATTENTION:
+            atol = 5e-1  # TODO: figure out why flex_attention has such large
+            # numerical differences for medium_decode, medium_prefill,
+            # mixed_medium
 
         max_diff = torch.max(torch.abs(backend_output - sdpa_output)).item()
         max_rel_diff = torch.max(
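
The move from string names to `_Backend` enum members keeps the test's backend list in the same vocabulary vLLM uses for backend selection, which is also what lets `TRITON_ATTN_VLLM_V1` be added without inventing a new string key. A minimal sketch of how a test could be parametrized over the enum-based list, assuming only the `_Backend` import and the `BACKENDS_TO_TEST` definition from the diff above (the test function itself is illustrative, not the one in this file):

```python
import pytest

from vllm.platforms import _Backend

# Same list as in the diff; TRITON_ATTN_VLLM_V1 is the newly added entry.
BACKENDS_TO_TEST = [
    _Backend.FLASH_ATTN_VLLM_V1, _Backend.FLASHINFER_VLLM_V1,
    _Backend.FLEX_ATTENTION, _Backend.TRITON_ATTN_VLLM_V1
]


@pytest.mark.parametrize("backend", BACKENDS_TO_TEST, ids=lambda b: b.name)
def test_backend_is_enum_member(backend):
    # Enum members give exact, typo-proof comparisons: the flashinfer
    # special-casing above becomes `backend == _Backend.FLASHINFER_VLLM_V1`
    # instead of matching a bare string.
    assert isinstance(backend, _Backend)
```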

tests/v1/attention/utils.py

Lines changed: 13 additions & 11 deletions
@@ -11,6 +11,8 @@
 from vllm.config import (CacheConfig, CompilationConfig, DeviceConfig,
                          LoadConfig, ModelConfig, ModelDType, ParallelConfig,
                          SchedulerConfig, VllmConfig)
+from vllm.platforms import _Backend
+from vllm.utils import resolve_obj_by_qualname
 from vllm.v1.attention.backends.utils import CommonAttentionMetadata
 from vllm.v1.kv_cache_interface import FullAttentionSpec
 
@@ -92,7 +94,7 @@ def create_common_attn_metadata(
     )
 
 
-def get_attention_backend(backend_name: str):
+def get_attention_backend(backend_name: _Backend):
     """Set up attention backend classes for testing.
 
     Args:
@@ -103,23 +105,23 @@ def get_attention_backend(backend_name: str):
         Tuple of (backend_builder_class, backend_impl_class)
     """
     backend_map = {
-        "flash_attn":
-        ("vllm.v1.attention.backends.flash_attn", "FlashAttentionBackend"),
-        "flashinfer":
-        ("vllm.v1.attention.backends.flashinfer", "FlashInferBackend"),
-        "flex_attention":
-        ("vllm.v1.attention.backends.flex_attention", "FlexAttentionBackend"),
+        _Backend.FLASH_ATTN_VLLM_V1:
+        "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend",
+        _Backend.FLASHINFER_VLLM_V1:
+        "vllm.v1.attention.backends.flashinfer.FlashInferBackend",
+        _Backend.FLEX_ATTENTION:
+        "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend",
+        _Backend.TRITON_ATTN_VLLM_V1:
+        "vllm.v1.attention.backends.triton_attn.TritonAttnBackend",
     }
 
     if backend_name not in backend_map:
         raise ValueError(f"Unknown backend: {backend_name}")
 
-    module_name, backend_class_name = backend_map[backend_name]
+    backend_class_name = backend_map[backend_name]
 
     try:
-        import importlib
-        module = importlib.import_module(module_name)
-        backend_class = getattr(module, backend_class_name)
+        backend_class = resolve_obj_by_qualname(backend_class_name)
         return backend_class.get_builder_cls(), backend_class.get_impl_cls()
     except ImportError as e:
         pytest.skip(f"{backend_name} not available: {e}")
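
The backend map now stores one fully qualified class path per backend and resolves it with `vllm.utils.resolve_obj_by_qualname`, instead of carrying (module, class-name) tuples and importing by hand. A rough sketch of what such a resolver typically does, written as a local stand-in rather than vLLM's actual implementation (the helper name `_resolve_by_qualname` is illustrative):

```python
import importlib
from typing import Any


def _resolve_by_qualname(qualname: str) -> Any:
    """Split 'pkg.module.ClassName' at the last dot, import the module part,
    and return the named attribute from it."""
    module_name, _, attr_name = qualname.rpartition(".")
    module = importlib.import_module(module_name)
    return getattr(module, attr_name)


# Mirrors how get_attention_backend() uses the resolved class:
backend_cls = _resolve_by_qualname(
    "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend")
builder_cls = backend_cls.get_builder_cls()
impl_cls = backend_cls.get_impl_cls()
```

Keeping the map values as plain dotted paths makes each entry a single copy-pasteable string and drops the importlib boilerplate from the try block.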

vllm/v1/attention/backends/flash_attn.py

Lines changed: 4 additions & 0 deletions
@@ -208,6 +208,10 @@ def build(self,
               common_prefix_len: int,
               common_attn_metadata: CommonAttentionMetadata,
               fast_build: bool = False) -> FlashAttentionMetadata:
+        """
+        fast_build disables AOT scheduling, used when there will be few
+        iterations i.e. spec-decode
+        """
         num_reqs = common_attn_metadata.num_reqs
         num_actual_tokens = common_attn_metadata.num_actual_tokens
         max_query_len = common_attn_metadata.max_query_len
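
The added docstring documents the `fast_build` flag on the FlashAttention metadata builder: it skips the ahead-of-time (AOT) scheduling setup when the resulting metadata will only be used for a few iterations, as in speculative decoding. A hedged sketch of a call site, assuming a builder instance and a prepared `CommonAttentionMetadata` already exist and relying only on the `build()` signature shown in the diff above (variable names are illustrative):

```python
# `builder` and `common_attn_metadata` are assumed to be set up by the caller.

# Regular prefill/decode step: the default fast_build=False keeps AOT
# scheduling enabled.
attn_metadata = builder.build(
    common_prefix_len=0,
    common_attn_metadata=common_attn_metadata,
)

# Spec-decode style step: the metadata is short-lived, so skip the AOT
# scheduler work with fast_build=True.
draft_metadata = builder.build(
    common_prefix_len=0,
    common_attn_metadata=common_attn_metadata,
    fast_build=True,
)
```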
