
Commit 9fb52e5

[V1] Support any head size for FlexAttention backend (#20467)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
1 parent e202dd2 commit 9fb52e5

File tree: 20 files changed, +202 -118 lines

.buildkite/scripts/hardware_ci/run-amd-test.sh

Lines changed: 1 addition & 2 deletions
@@ -107,10 +107,9 @@ fi

 if [[ $commands == *" kernels/attention"* ]]; then
   commands="${commands} \
-  --ignore=kernels/attention/stest_attention_selector.py \
+  --ignore=kernels/attention/test_attention_selector.py \
   --ignore=kernels/attention/test_blocksparse_attention.py \
   --ignore=kernels/attention/test_encoder_decoder_attn.py \
-  --ignore=kernels/attention/test_attention_selector.py \
   --ignore=kernels/attention/test_flash_attn.py \
   --ignore=kernels/attention/test_flashinfer.py \
   --ignore=kernels/attention/test_prefix_prefill.py \

docs/models/supported_models.md

Lines changed: 2 additions & 8 deletions
@@ -626,9 +626,6 @@ Specified using `--task generate`.
 !!! note
     Only `InternVLChatModel` with Qwen2.5 text backbone (`OpenGVLab/InternVL3-2B`, `OpenGVLab/InternVL2.5-1B` etc) has video inputs support currently.

-!!! note
-    `h2oai/h2ovl-mississippi-2b` will be available in V1 once we support head size 80.
-
 !!! note
     To use `TIGER-Lab/Mantis-8B-siglip-llama3`, you have to pass `--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'` when running vLLM.

@@ -671,11 +668,8 @@ Specified using `--task generate`.
     Our PaliGemma implementations have the same problem as Gemma 3 (see above) for both V0 and V1.

 !!! note
-    To use Qwen2.5-Omni, you have to install Hugging Face Transformers library from source via
-    `pip install git+https://github.com/huggingface/transformers.git`.
-
-    Read audio from video pre-processing is currently supported on V0 (but not V1), because overlapping modalities is not yet supported in V1.
-    `--mm-processor-kwargs '{"use_audio_in_video": true}'`.
+    For Qwen2.5-Omni, reading audio from video pre-processing (`--mm-processor-kwargs '{"use_audio_in_video": true}'`)
+    is currently supported on V0 (but not V1), because overlapping modalities is not yet supported in V1.

 #### Transcription

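The revised note folds the `--mm-processor-kwargs` flag into the Qwen2.5-Omni description. As a hedged illustration only (the model ID below is an assumption, not taken from this diff), the same option can be passed through the offline API; per the note, this path is currently V0-only:

```python
# Hedged sketch: offline equivalent of the --mm-processor-kwargs flag above.
# The model ID is assumed for illustration; per the note, reading audio from
# video pre-processing is currently supported on V0 but not V1.
from vllm import LLM

llm = LLM(
    model="Qwen/Qwen2.5-Omni-7B",  # assumed model ID
    mm_processor_kwargs={"use_audio_in_video": True},
)
```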
examples/offline_inference/vision_language.py

Lines changed: 2 additions & 2 deletions
@@ -98,7 +98,7 @@ def run_blip2(questions: list[str], modality: str) -> ModelRequestData:
     # See https://huggingface.co/Salesforce/blip2-opt-2.7b/discussions/15#64ff02f3f8cf9e4f5b038262 #noqa
     prompts = [f"Question: {question} Answer:" for question in questions]
     engine_args = EngineArgs(
-        model="Salesforce/blip2-opt-6.7b",
+        model="Salesforce/blip2-opt-2.7b",
         limit_mm_per_prompt={modality: 1},
     )

@@ -971,7 +971,7 @@ def run_pixtral_hf(questions: list[str], modality: str) -> ModelRequestData:
     )


-# Qwen
+# Qwen-VL
 def run_qwen_vl(questions: list[str], modality: str) -> ModelRequestData:
     assert modality == "image"


tests/kernels/attention/test_attention_selector.py

Lines changed: 12 additions & 1 deletion
@@ -172,7 +172,7 @@ def test_env(
             expected = "FLASHINFER_VLLM_V1" if use_v1 else name
             assert backend.get_name() == expected
         else:
-            backend = get_attn_backend(16,
+            backend = get_attn_backend(32,
                                        torch.float16,
                                        torch.float16,
                                        block_size,
@@ -181,6 +181,17 @@
             expected = "FLASH_ATTN_VLLM_V1" if use_v1 else name
             assert backend.get_name() == expected

+        if use_v1:
+            backend = get_attn_backend(16,
+                                       torch.float16,
+                                       torch.float16,
+                                       block_size,
+                                       False,
+                                       use_mla=use_mla)
+            assert backend.get_name() == "FLEX_ATTENTION", (
+                "Should fallback to FlexAttention if head size is "
+                "not supported by FlashAttention")
+

 @pytest.mark.parametrize("device", ["cpu", "cuda"])
 @pytest.mark.parametrize("use_v1", [True, False])

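The new assertion above captures the point of this commit: under V1, a head size that FlashAttention does not support should now resolve to FlexAttention instead of failing outright. A minimal standalone sketch of the same check (assuming a CUDA-capable machine with `VLLM_USE_V1=1`; the block size is arbitrary):

```python
# Hedged sketch of the fallback asserted in the test above.
# Assumes a CUDA-capable machine and VLLM_USE_V1=1 in the environment.
import torch

from vllm.attention.selector import get_attn_backend

backend = get_attn_backend(
    16,             # head size that FlashAttention does not support
    torch.float16,  # dtype
    torch.float16,  # kv_cache_dtype
    16,             # block_size (arbitrary for this check)
    False,          # is_attention_free
)
# With this commit, the selector falls back to FlexAttention instead of raising.
assert backend.get_name() == "FLEX_ATTENTION"
```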
tests/models/multimodal/generation/test_common.py

Lines changed: 2 additions & 7 deletions
@@ -33,9 +33,6 @@
 os.environ["VLLM_USE_TRITON_FLASH_ATTN"] = "0"

 REQUIRES_V0_MODELS = [
-    # V1 Test: no way to fall back for head_dim = 80
-    # https://github.com/vllm-project/vllm/issues/14524
-    "qwen_vl",
     # V1 Test: not enough KV cache space in C1.
     "fuyu",
 ]
@@ -221,8 +218,7 @@
         marks=[large_gpu_mark(min_gb=32)],
     ),
     "blip2": VLMTestInfo(
-        # TODO: Change back to 2.7b once head_dim = 80 is supported
-        models=["Salesforce/blip2-opt-6.7b"],
+        models=["Salesforce/blip2-opt-2.7b"],
         test_type=VLMTestType.IMAGE,
         prompt_formatter=lambda img_prompt: f"Question: {img_prompt} Answer:",
         img_idx_to_prompt=lambda idx: "",
@@ -340,8 +336,7 @@
     "h2ovl": VLMTestInfo(
         models = [
             "h2oai/h2ovl-mississippi-800m",
-            # TODO: Re-enable once head_dim = 80 is supported
-            # "h2oai/h2ovl-mississippi-2b",
+            "h2oai/h2ovl-mississippi-2b",
         ],
         test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
         prompt_formatter=lambda img_prompt: f"<|prompt|>{img_prompt}<|end|><|answer|>", # noqa: E501

tests/models/quantization/test_gguf.py

Lines changed: 1 addition & 1 deletion
@@ -83,7 +83,7 @@ def gguf_model(self):
     QWEN2_CONFIG,
     PHI3_CONFIG,
     GPT2_CONFIG,
-    # STABLELM_CONFIG, # enable this when v1 support head_size=80
+    STABLELM_CONFIG,
     DOLPHIN_CONFIG,
     # STARCODER_CONFIG, # broken
 ]

tests/models/registry.py

Lines changed: 6 additions & 9 deletions
@@ -240,8 +240,9 @@ def check_available_online(
     "OrionForCausalLM": _HfExamplesInfo("OrionStarAI/Orion-14B-Chat",
                                         trust_remote_code=True),
     "PersimmonForCausalLM": _HfExamplesInfo("adept/persimmon-8b-chat"),
-    "PhiForCausalLM": _HfExamplesInfo("microsoft/phi-2", v0_only=True),
+    "PhiForCausalLM": _HfExamplesInfo("microsoft/phi-2"),
     "Phi3ForCausalLM": _HfExamplesInfo("microsoft/Phi-3-mini-4k-instruct"),
+    # Blocksparse attention not supported in V1 yet
     "Phi3SmallForCausalLM": _HfExamplesInfo("microsoft/Phi-3-small-8k-instruct",
                                             trust_remote_code=True,
                                             v0_only=True),
@@ -258,10 +259,8 @@ def check_available_online(
     "Qwen3MoeForCausalLM": _HfExamplesInfo("Qwen/Qwen3-30B-A3B"),
     "Qwen3ForSequenceClassification": _HfExamplesInfo("tomaarsen/Qwen3-Reranker-0.6B-seq-cls"), # noqa: E501
     "RWForCausalLM": _HfExamplesInfo("tiiuae/falcon-40b"),
-    "StableLMEpochForCausalLM": _HfExamplesInfo("stabilityai/stablelm-zephyr-3b", # noqa: E501
-                                                v0_only=True),
-    "StableLmForCausalLM": _HfExamplesInfo("stabilityai/stablelm-3b-4e1t",
-                                           v0_only=True),
+    "StableLMEpochForCausalLM": _HfExamplesInfo("stabilityai/stablelm-zephyr-3b"), # noqa: E501
+    "StableLmForCausalLM": _HfExamplesInfo("stabilityai/stablelm-3b-4e1t"),
     "Starcoder2ForCausalLM": _HfExamplesInfo("bigcode/starcoder2-3b"),
     "SolarForCausalLM": _HfExamplesInfo("upstage/solar-pro-preview-instruct"),
     "TeleChat2ForCausalLM": _HfExamplesInfo("Tele-AI/TeleChat2-3B",
@@ -330,8 +329,7 @@ def check_available_online(
     "AriaForConditionalGeneration": _HfExamplesInfo("rhymes-ai/Aria"),
     "AyaVisionForConditionalGeneration": _HfExamplesInfo("CohereForAI/aya-vision-8b"), # noqa: E501
     "Blip2ForConditionalGeneration": _HfExamplesInfo("Salesforce/blip2-opt-2.7b", # noqa: E501
-                                                     extras={"6b": "Salesforce/blip2-opt-6.7b"}, # noqa: E501
-                                                     v0_only=True),
+                                                     extras={"6b": "Salesforce/blip2-opt-6.7b"}), # noqa: E501
     "ChameleonForConditionalGeneration": _HfExamplesInfo("facebook/chameleon-7b"), # noqa: E501
     "DeepseekVLV2ForCausalLM": _HfExamplesInfo("deepseek-ai/deepseek-vl2-tiny", # noqa: E501
                                                extras={"fork": "Isotr0py/deepseek-vl2-tiny"}, # noqa: E501
@@ -359,8 +357,7 @@ def check_available_online(
                                          trust_remote_code=True),
     "KimiVLForConditionalGeneration": _HfExamplesInfo("moonshotai/Kimi-VL-A3B-Instruct", # noqa: E501
                                                       extras={"thinking": "moonshotai/Kimi-VL-A3B-Thinking"}, # noqa: E501
-                                                      trust_remote_code=True,
-                                                      v0_only=True),
+                                                      trust_remote_code=True),
     "Llama4ForConditionalGeneration": _HfExamplesInfo("meta-llama/Llama-4-Scout-17B-16E-Instruct", # noqa: E501
                                                       max_model_len=10240),
     "LlavaForConditionalGeneration": _HfExamplesInfo("llava-hf/llava-1.5-7b-hf",

tests/models/test_initialization.py

Lines changed: 2 additions & 1 deletion
@@ -22,7 +22,8 @@ def test_can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch):
     model_info.check_transformers_version(on_fail="skip")

     # FIXME: Possible memory leak in the previous tests?
-    if model_arch == "GraniteSpeechForConditionalGeneration":
+    if model_arch in ("GraniteSpeechForConditionalGeneration",
+                      "KimiVLForConditionalGeneration"):
         pytest.skip("Avoid OOM")

     # Avoid OOM and reduce initialization time by only using 1 layer

vllm/attention/layer.py

Lines changed: 2 additions & 1 deletion
@@ -310,7 +310,8 @@ def __init__(
             # currently, only torch_sdpa is supported on rocm
             self.attn_backend = _Backend.TORCH_SDPA
         else:
-            if backend in {_Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1}:
+            if backend in (_Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1,
+                           _Backend.FLEX_ATTENTION):
                 backend = _Backend.XFORMERS

             self.attn_backend = backend if backend in {

vllm/attention/selector.py

Lines changed: 30 additions & 3 deletions
@@ -4,7 +4,7 @@
 import os
 from contextlib import contextmanager
 from functools import cache
-from typing import Generator, Optional, Type
+from typing import Generator, Optional, Union

 import torch

@@ -79,6 +79,33 @@ def get_global_forced_attn_backend() -> Optional[_Backend]:
     return forced_attn_backend


+def supports_head_size(
+    attn_backend: Union[str, type[AttentionBackend]],
+    head_size: int,
+) -> bool:
+    if isinstance(attn_backend, str):
+        try:
+            attn_backend = resolve_obj_by_qualname(attn_backend)
+        except ImportError:
+            return False
+
+    assert isinstance(attn_backend, type)
+
+    # TODO: Update the interface once V0 is removed
+    if get_supported_head_sizes := getattr(attn_backend,
+                                           "get_supported_head_sizes", None):
+        return head_size in get_supported_head_sizes()
+    if validate_head_size := getattr(attn_backend, "validate_head_size", None):
+        try:
+            validate_head_size(head_size)
+            return True
+        except Exception:
+            return False
+
+    raise NotImplementedError(f"{attn_backend.__name__} does not support "
+                              "head size validation")
+
+
 def get_attn_backend(
     head_size: int,
     dtype: torch.dtype,
@@ -87,7 +114,7 @@ def get_attn_backend(
     is_attention_free: bool,
     is_blocksparse: bool = False,
     use_mla: bool = False,
-) -> Type[AttentionBackend]:
+) -> type[AttentionBackend]:
     """Selects which attention backend to use and lazily imports it."""
     # Accessing envs.* behind an @lru_cache decorator can cause the wrong
     # value to be returned from the cache if the value changes between calls.
@@ -115,7 +142,7 @@ def _cached_get_attn_backend(
     is_blocksparse: bool = False,
     use_v1: bool = False,
     use_mla: bool = False,
-) -> Type[AttentionBackend]:
+) -> type[AttentionBackend]:
     if is_blocksparse:
         logger.info("Using BlocksparseFlashAttention backend.")
         from vllm.attention.backends.blocksparse_attn import (

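The new `supports_head_size()` helper accepts either a backend class or its qualified name and probes `get_supported_head_sizes()` / `validate_head_size()`. A hedged sketch of how a caller could use it to decide on the FlexAttention fallback; the qualified backend names below are assumptions for illustration, not taken from this diff:

```python
# Hedged sketch: using the new supports_head_size() helper to pick a backend.
# The qualified names below are assumptions for illustration only.
from vllm.attention.selector import supports_head_size

FLASH_ATTN_V1 = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"
FLEX_ATTENTION = "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend"


def choose_backend(head_size: int) -> str:
    # Prefer FlashAttention when it supports the model's head size;
    # otherwise fall back to FlexAttention, which accepts any head size.
    if supports_head_size(FLASH_ATTN_V1, head_size):
        return FLASH_ATTN_V1
    return FLEX_ATTENTION


# Example: head size 80 (e.g. blip2-opt-2.7b, StableLM) is not supported by
# FlashAttention, so it resolves to the FlexAttention backend.
print(choose_backend(128))
print(choose_backend(80))
```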