Commit e8cc53a

[Misc] Log the reason for falling back to FlexAttention (#20699)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
1 parent a4851cf commit e8cc53a

File tree

10 files changed: +104 -32 lines changed

vllm/attention/selector.py

Lines changed: 40 additions & 9 deletions
@@ -3,6 +3,7 @@
 
 import os
 from contextlib import contextmanager
+from dataclasses import dataclass
 from functools import cache
 from typing import Generator, Optional, Union
 
@@ -79,31 +80,61 @@ def get_global_forced_attn_backend() -> Optional[_Backend]:
     return forced_attn_backend
 
 
-def supports_head_size(
+@dataclass(frozen=True)
+class _IsSupported:
+    can_import: bool
+    head_size: bool
+    dtype: bool
+
+    def __bool__(self) -> bool:
+        return self.can_import and self.head_size and self.dtype
+
+
+def is_attn_backend_supported(
     attn_backend: Union[str, type[AttentionBackend]],
     head_size: int,
-) -> bool:
+    dtype: torch.dtype,
+    *,
+    allow_import_error: bool = True,
+) -> _IsSupported:
     if isinstance(attn_backend, str):
         try:
             attn_backend = resolve_obj_by_qualname(attn_backend)
         except ImportError:
-            return False
+            if not allow_import_error:
+                raise
+
+            return _IsSupported(can_import=False, head_size=False, dtype=False)
 
     assert isinstance(attn_backend, type)
 
     # TODO: Update the interface once V0 is removed
     if get_supported_head_sizes := getattr(attn_backend,
                                            "get_supported_head_sizes", None):
-        return head_size in get_supported_head_sizes()
-    if validate_head_size := getattr(attn_backend, "validate_head_size", None):
+        is_head_size_supported = head_size in get_supported_head_sizes()
+    elif validate_head_size := getattr(attn_backend, "validate_head_size",
+                                       None):
        try:
             validate_head_size(head_size)
-            return True
+            is_head_size_supported = True
         except Exception:
-            return False
+            is_head_size_supported = False
+    else:
+        raise NotImplementedError(f"{attn_backend.__name__} does not support "
+                                  "head size validation")
+
+    if get_supported_dtypes := getattr(attn_backend, "get_supported_dtypes",
+                                       None):
+        is_dtype_supported = dtype in get_supported_dtypes()
+    else:
+        raise NotImplementedError(f"{attn_backend.__name__} does not support "
+                                  "dtype validation")
 
-    raise NotImplementedError(f"{attn_backend.__name__} does not support "
-                              "head size validation")
+    return _IsSupported(
+        can_import=True,
+        head_size=is_head_size_supported,
+        dtype=is_dtype_supported,
+    )
 
 
 def get_attn_backend(
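
For illustration, here is a minimal standalone sketch (not vLLM code; DummyBackend and check are hypothetical stand-ins) of how the _IsSupported result introduced above behaves: the object is truthy only when every check passes, while the individual flags record which check failed.

import torch
from dataclasses import dataclass


@dataclass(frozen=True)
class IsSupported:
    # Mirrors the _IsSupported result above: truthy only if every check passes.
    can_import: bool
    head_size: bool
    dtype: bool

    def __bool__(self) -> bool:
        return self.can_import and self.head_size and self.dtype


class DummyBackend:
    # Hypothetical backend used only for this illustration.
    @classmethod
    def get_supported_head_sizes(cls) -> list[int]:
        return [64, 128]

    @classmethod
    def get_supported_dtypes(cls) -> list[torch.dtype]:
        return [torch.float16, torch.bfloat16]


def check(backend, head_size: int, dtype: torch.dtype) -> IsSupported:
    # Simplified stand-in for is_attn_backend_supported (no import handling).
    return IsSupported(
        can_import=True,
        head_size=head_size in backend.get_supported_head_sizes(),
        dtype=dtype in backend.get_supported_dtypes(),
    )


result = check(DummyBackend, head_size=96, dtype=torch.float32)
assert not result              # overall check fails ...
assert not result.head_size    # ... because 96 is not in [64, 128]
assert not result.dtype        # ... and float32 is not supported here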

vllm/platforms/cuda.py

Lines changed: 35 additions & 22 deletions
@@ -259,43 +259,56 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
             logger.info_once("Using Flash Attention backend on V1 engine.")
             return FLASH_ATTN_V1
 
-        from vllm.attention.selector import supports_head_size
+        from vllm.attention.selector import is_attn_backend_supported
 
         # Default backends for V1 engine
-        # FP32 is only supported by FlexAttention
-        if dtype not in (torch.float16, torch.bfloat16):
-            logger.info_once(
-                "Using FlexAttention backend for %s on V1 engine.",
-                dtype,
-            )
-            return FLEX_ATTENTION_V1
-
         # Prefer FlashInfer for Blackwell GPUs if installed
-        if cls.is_device_capability(100) and \
-                supports_head_size(FLASHINFER_V1, head_size):
-            try:
-                import flashinfer  # noqa: F401
-
+        if cls.is_device_capability(100):
+            if is_default_backend_supported := is_attn_backend_supported(
+                    FLASHINFER_V1, head_size, dtype):
                 from vllm.v1.attention.backends.utils import (
                     set_kv_cache_layout)
+
                 logger.info_once(
                     "Using FlashInfer backend with HND KV cache layout on "
                     "V1 engine by default for Blackwell (SM 10.0) GPUs.")
                 set_kv_cache_layout("HND")
+
                 return FLASHINFER_V1
-            except ImportError:
-                logger.info_once(
+
+            if not is_default_backend_supported.can_import:
+                logger.warning_once(
                     "FlashInfer failed to import for V1 engine on "
                     "Blackwell (SM 10.0) GPUs; it is recommended to "
                     "install FlashInfer for better performance.")
-                pass
+
         # FlashAttention is the default for SM 8.0+ GPUs
-        if cls.has_device_capability(80) and \
-                supports_head_size(FLASH_ATTN_V1, head_size):
-            logger.info_once("Using Flash Attention backend on V1 engine.")
-            return FLASH_ATTN_V1
+        if cls.has_device_capability(80):
+            if is_default_backend_supported := is_attn_backend_supported(
+                    FLASH_ATTN_V1, head_size, dtype,
+                    allow_import_error=False):
+                logger.info_once("Using Flash Attention backend on "
+                                 "V1 engine.")
+                return FLASH_ATTN_V1
+
+        # FlexAttention is the default for older GPUs
+        else:
+            logger.info_once("Using FlexAttention backend on V1 engine.")
+            return FLEX_ATTENTION_V1
+
+        assert not is_default_backend_supported
+
+        use_flex_attention_reason = {}
+        if not is_default_backend_supported.head_size:
+            use_flex_attention_reason["head_size"] = head_size
+        if not is_default_backend_supported.dtype:
+            use_flex_attention_reason["dtype"] = dtype
 
-        logger.info_once("Using FlexAttention backend on V1 engine.")
+        logger.info_once(
+            "Using FlexAttention backend for %s on V1 engine.",
+            ", ".join(f"{k}={v}"
+                      for k, v in use_flex_attention_reason.items()),
+        )
         return FLEX_ATTENTION_V1
 
         # Backends for V0 engine
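
As a rough illustration of what the new fallback log contains (the flag values, head size, and dtype below are made up), the reasons are collected into a dict and rendered as a comma-separated key=value list:

import torch

# Hypothetical outcome: the default backend rejected both the head size and dtype.
head_size_supported = False
dtype_supported = False
head_size, dtype = 72, torch.float32

use_flex_attention_reason = {}
if not head_size_supported:
    use_flex_attention_reason["head_size"] = head_size
if not dtype_supported:
    use_flex_attention_reason["dtype"] = dtype

# Same formatting as the logger call in the diff above.
reason = ", ".join(f"{k}={v}" for k, v in use_flex_attention_reason.items())
print(f"Using FlexAttention backend for {reason} on V1 engine.")
# Prints: Using FlexAttention backend for head_size=72, dtype=torch.float32 on V1 engine.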

vllm/reasoning/hunyuan_a13b_reasoning_parser.py

Lines changed: 1 addition & 1 deletion
@@ -1,10 +1,10 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
-import re
 from collections.abc import Sequence
 from typing import Optional, Union
 
+import regex as re
 from transformers import PreTrainedTokenizerBase
 
 from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
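
For the patterns used here, the third-party regex package behaves as a drop-in replacement for the stdlib re module while adding features such as match timeouts. A small sketch of the aliased import; the tag pattern is illustrative only, not the parser's actual expression:

# regex exposes the same core API as the stdlib re module, so existing
# patterns keep working unchanged after the aliased import.
import regex as re

# Illustrative tag pattern only; not the parser's actual expression.
match = re.search(r"<answer>(.*?)</answer>", "<answer>42</answer>", re.DOTALL)
assert match is not None and match.group(1) == "42"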

vllm/v1/attention/backends/cpu_attn.py

Lines changed: 4 additions & 0 deletions
@@ -37,6 +37,10 @@
 class TorchSDPABackend(AttentionBackend):
     accept_output_buffer: bool = False
 
+    @classmethod
+    def get_supported_dtypes(cls) -> list[torch.dtype]:
+        return [torch.float16, torch.bfloat16, torch.float32]
+
     @classmethod
     def validate_head_size(cls, head_size: int) -> None:
         attn_impl = _get_paged_attn_impl()
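
The selector discovers these new get_supported_dtypes classmethods with getattr, so a backend that has not declared its dtypes fails loudly rather than being treated as compatible. A minimal sketch of that probing pattern, using hypothetical backend classes rather than the real ones:

import torch


class DeclaredBackend:
    # Hypothetical backend that declares its supported dtypes.
    @classmethod
    def get_supported_dtypes(cls) -> list[torch.dtype]:
        return [torch.float16, torch.bfloat16]


class LegacyBackend:
    # Hypothetical backend that predates the new interface.
    pass


def dtype_supported(backend, dtype: torch.dtype) -> bool:
    # Same getattr-based probing style as is_attn_backend_supported above.
    if get_supported_dtypes := getattr(backend, "get_supported_dtypes", None):
        return dtype in get_supported_dtypes()
    raise NotImplementedError(
        f"{backend.__name__} does not support dtype validation")


assert dtype_supported(DeclaredBackend, torch.bfloat16)
assert not dtype_supported(DeclaredBackend, torch.float32)
try:
    dtype_supported(LegacyBackend, torch.float16)
except NotImplementedError as e:
    print(e)  # LegacyBackend does not support dtype validation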

vllm/v1/attention/backends/flash_attn.py

Lines changed: 4 additions & 0 deletions
@@ -44,6 +44,10 @@ class FlashAttentionBackend(AttentionBackend):
 
     accept_output_buffer: bool = True
 
+    @classmethod
+    def get_supported_dtypes(cls) -> list[torch.dtype]:
+        return [torch.float16, torch.bfloat16]
+
     @classmethod
     def get_supported_head_sizes(cls) -> list[int]:
         return [32, 64, 96, 128, 160, 192, 224, 256]

vllm/v1/attention/backends/flashinfer.py

Lines changed: 4 additions & 0 deletions
@@ -42,6 +42,10 @@ class FlashInferBackend(AttentionBackend):
     accept_output_buffer: bool = True
     cached_sm100a_supported: Optional[bool] = None
 
+    @classmethod
+    def get_supported_dtypes(cls) -> list[torch.dtype]:
+        return [torch.float16, torch.bfloat16]
+
     @classmethod
     def get_supported_head_sizes(cls) -> list[int]:
         # https://github.com/flashinfer-ai/flashinfer/blob/3d55c71a62052c590c130897d3a3db49b14fcc34/include/flashinfer/utils.cuh#L157

vllm/v1/attention/backends/flex_attention.py

Lines changed: 4 additions & 0 deletions
@@ -42,6 +42,10 @@ def _offsets_to_doc_ids_tensor(offsets: torch.Tensor) -> torch.Tensor:
 class FlexAttentionBackend(AttentionBackend):
     accept_output_buffer: bool = True
 
+    @classmethod
+    def get_supported_dtypes(cls) -> list[torch.dtype]:
+        return [torch.float16, torch.bfloat16, torch.float32]
+
     @classmethod
     def validate_head_size(cls, head_size: int) -> None:
         return  # FlexAttention supports any head size

vllm/v1/attention/backends/mla/common.py

Lines changed: 4 additions & 0 deletions
@@ -262,6 +262,10 @@ def get_kv_cache_shape(
     ) -> tuple[int, ...]:
         return (num_blocks, block_size, head_size)
 
+    @classmethod
+    def get_supported_dtypes(cls) -> list[torch.dtype]:
+        return [torch.float16, torch.bfloat16]
+
     @classmethod
     def get_supported_head_sizes(cls) -> list[int]:
         return [576]

vllm/v1/attention/backends/rocm_aiter_fa.py

Lines changed: 4 additions & 0 deletions
@@ -314,6 +314,10 @@ class AiterFlashAttentionBackend(AttentionBackend):
 
     accept_output_buffer: bool = True
 
+    @classmethod
+    def get_supported_dtypes(cls) -> list[torch.dtype]:
+        return [torch.float16, torch.bfloat16]
+
     @classmethod
     def get_supported_head_sizes(cls) -> list[int]:
         return [32, 64, 96, 128, 160, 192, 224, 256]

vllm/v1/attention/backends/triton_attn.py

Lines changed: 4 additions & 0 deletions
@@ -190,6 +190,10 @@ class TritonAttentionBackend(AttentionBackend):
 
     accept_output_buffer: bool = True
 
+    @classmethod
+    def get_supported_dtypes(cls) -> list[torch.dtype]:
+        return [torch.float16, torch.bfloat16]
+
     @classmethod
     def get_supported_head_sizes(cls) -> list[int]:
         return [32, 64, 96, 128, 160, 192, 224, 256]
