
Commit b808919

[v1] Add fp32 support to v1 engine through flex attn (#19319)
Signed-off-by: Isotr0py <2037008807@qq.com>
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
1 parent 770e5dc commit b808919

File tree — 4 files changed: +38 −7 lines

  tests/kernels/attention/test_attention_selector.py
  tests/v1/worker/test_gpu_model_runner.py
  vllm/engine/arg_utils.py
  vllm/platforms/cuda.py

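In effect, a float32 model no longer forces a fallback to the V0 engine: on CUDA, any dtype other than fp16/bf16 is now routed to the FlexAttention backend (see vllm/platforms/cuda.py below). A minimal usage sketch, assuming a CUDA device and a vLLM build that includes this commit (the model name is illustrative):

import os

from vllm import LLM, SamplingParams

os.environ["VLLM_USE_V1"] = "1"  # keep the V1 engine; fp32 used to trip the V0 fallback

# dtype="float32" now stays on V1 and selects the FlexAttention backend on CUDA.
llm = LLM(model="facebook/opt-125m", dtype="float32")
outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)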

tests/kernels/attention/test_attention_selector.py

Lines changed: 28 additions & 0 deletions
@@ -183,6 +183,34 @@ def test_env(
         assert backend.get_name() == expected
 
 
+@pytest.mark.parametrize("device", ["cpu", "cuda"])
+@pytest.mark.parametrize("use_v1", [True, False])
+def test_fp32_fallback(
+    device: str,
+    use_v1: bool,
+    monkeypatch: pytest.MonkeyPatch,
+):
+    """Test attention backend selection with fp32."""
+    with monkeypatch.context() as m:
+        m.setenv("VLLM_USE_V1", "1" if use_v1 else "0")
+
+        if device == "cpu":
+            with patch("vllm.attention.selector.current_platform",
+                       CpuPlatform()):
+                backend = get_attn_backend(16, torch.float32, torch.float32,
+                                            16, False)
+            assert (backend.get_name() == "TORCH_SDPA_VLLM_V1"
+                    if use_v1 else "TORCH_SDPA")
+
+        elif device == "cuda":
+            with patch("vllm.attention.selector.current_platform",
+                       CudaPlatform()):
+                backend = get_attn_backend(16, torch.float32, torch.float32,
+                                            16, False)
+            assert (backend.get_name() == "FLEX_ATTENTION"
+                    if use_v1 else "XFORMERS")
+
+
 def test_flash_attn(monkeypatch: pytest.MonkeyPatch):
     """Test FlashAttn validation."""
     # TODO: When testing for v1, pipe in `use_v1` as an argument to
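Note on the assertion style in test_fp32_fallback: `assert (backend.get_name() == "TORCH_SDPA_VLLM_V1" if use_v1 else "TORCH_SDPA")` parses as a conditional expression, so when use_v1 is False the assert merely evaluates the truthy string "TORCH_SDPA" and always passes. A stricter formulation (a sketch only, not part of this commit) computes the expected name first:

# Sketch; mirrors the CPU branch of test_fp32_fallback above.
expected = "TORCH_SDPA_VLLM_V1" if use_v1 else "TORCH_SDPA"
assert backend.get_name() == expected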

tests/v1/worker/test_gpu_model_runner.py

Lines changed: 6 additions & 0 deletions
@@ -4,6 +4,7 @@
 import random
 
 import pytest
+import torch
 
 from vllm.attention import Attention
 from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
@@ -399,6 +400,7 @@ def test_load_model_weights_inplace(dist_init, model_runner, model_runner_2):
 
 
 def test_init_kv_cache_with_kv_sharing_invalid_target_layer_order():
+    torch.set_default_dtype(torch.float16)
     layer_0 = "model.layers.0.self_attn.attn"
     layer_1 = "model.layers.1.self_attn.attn"
     error_msg = f"{layer_1} must come before the current layer"
@@ -427,6 +429,7 @@ def test_init_kv_cache_with_kv_sharing_invalid_target_layer_order():
 
 
 def test_init_kv_cache_with_kv_sharing_target_layer_not_exist():
+    torch.set_default_dtype(torch.float16)
     layer_0 = "model.layers.0.self_attn.attn"
     layer_1 = "model.layers.1.self_attn.attn"
     invalid_layer = "model.layers.0.cross_attn.attn"
@@ -455,6 +458,7 @@ def test_init_kv_cache_with_kv_sharing_target_layer_not_exist():
 
 
 def test_init_kv_cache_with_kv_sharing_target_same_as_current():
+    torch.set_default_dtype(torch.float16)
     layer_0 = "model.layers.0.self_attn.attn"
     layer_1 = "model.layers.1.self_attn.attn"
     error_msg = f"{layer_1} cannot be the same as the current layer"
@@ -483,6 +487,7 @@ def test_init_kv_cache_with_kv_sharing_target_same_as_current():
 
 
 def test_init_kv_cache_without_kv_sharing():
+    torch.set_default_dtype(torch.float16)
     layer_0 = "model.layers.0.self_attn.attn"
     layer_1 = "model.layers.1.self_attn.attn"
     vllm_config = get_vllm_config()
@@ -550,6 +555,7 @@ def test_init_kv_cache_without_kv_sharing():
 
 
 def test_init_kv_cache_with_kv_sharing_valid():
+    torch.set_default_dtype(torch.float16)
     layer_0 = "model.layers.0.self_attn.attn"
     layer_1 = "model.layers.1.self_attn.attn"
     vllm_config = get_vllm_config()
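These model-runner tests build Attention layers with torch's default dtype, which is float32; now that fp32 no longer falls back to V0, each test pins the default to float16, presumably so the layers keep exercising the fp16 attention path the tests were written against. Since set_default_dtype() is process-global, an alternative (a sketch only, not part of this commit) is a fixture that restores the previous default after each test:

import pytest
import torch


@pytest.fixture
def default_dtype_fp16():
    # Pin the default dtype for one test, then restore it so later tests are unaffected.
    old = torch.get_default_dtype()
    torch.set_default_dtype(torch.float16)
    yield
    torch.set_default_dtype(old)


def test_init_kv_cache_example(default_dtype_fp16):
    # ...build the vLLM config and KV-cache layers as in the tests above...
    assert torch.get_default_dtype() == torch.float16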

vllm/engine/arg_utils.py

Lines changed: 0 additions & 7 deletions
@@ -1337,13 +1337,6 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
13371337
recommend_to_remove=False)
13381338
return False
13391339

1340-
# Only Fp16 and Bf16 dtypes since we only support FA.
1341-
V1_SUPPORTED_DTYPES = [torch.bfloat16, torch.float16]
1342-
if model_config.dtype not in V1_SUPPORTED_DTYPES:
1343-
_raise_or_fallback(feature_name=f"--dtype {model_config.dtype}",
1344-
recommend_to_remove=False)
1345-
return False
1346-
13471340
# No Embedding Models so far.
13481341
if model_config.task not in ["generate"]:
13491342
_raise_or_fallback(feature_name=f"--task {model_config.task}",

vllm/platforms/cuda.py

Lines changed: 4 additions & 0 deletions
@@ -233,6 +233,10 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
                 logger.info_once("Using Triton backend on V1 engine.")
                 return ("vllm.v1.attention.backends."
                         "triton_attn.TritonAttentionBackend")
+            if dtype not in (torch.float16, torch.bfloat16):
+                logger.info_once(
+                    f"Using FlexAttenion backend for {dtype} on V1 engine.")
+                return "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend"  # noqa: E501
             if cls.is_device_capability(100):
                 # Prefer FlashInfer for V1 on Blackwell GPUs if installed
                 try:
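The dtype check sits ahead of the Blackwell/FlashInfer preference, so any non-fp16/bf16 model dtype on CUDA now resolves to FlexAttention, which, unlike the FlashAttention kernels, also runs in float32. A standalone sketch (outside vLLM) of the underlying PyTorch API, assuming PyTorch >= 2.5 and a CUDA device:

import torch
from torch.nn.attention.flex_attention import flex_attention

# Shapes are (batch, num_heads, seq_len, head_dim); note the float32 dtype.
q = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float32)
k = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float32)
v = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float32)

out = flex_attention(q, k, v)  # eager path; typically wrapped in torch.compile for speed
print(out.shape, out.dtype)    # torch.Size([1, 8, 128, 64]) torch.float32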
