
Commit 5f1ac1e

Revert "[v1] Add fp32 support to v1 engine through flex attn" (vllm-project#19404)
1 parent 9368cc9 commit 5f1ac1e

File tree: 4 files changed, +7 −38 lines

tests/kernels/attention/test_attention_selector.py

Lines changed: 0 additions & 28 deletions
@@ -183,34 +183,6 @@ def test_env(
         assert backend.get_name() == expected


-@pytest.mark.parametrize("device", ["cpu", "cuda"])
-@pytest.mark.parametrize("use_v1", [True, False])
-def test_fp32_fallback(
-    device: str,
-    use_v1: bool,
-    monkeypatch: pytest.MonkeyPatch,
-):
-    """Test attention backend selection with fp32."""
-    with monkeypatch.context() as m:
-        m.setenv("VLLM_USE_V1", "1" if use_v1 else "0")
-
-        if device == "cpu":
-            with patch("vllm.attention.selector.current_platform",
-                       CpuPlatform()):
-                backend = get_attn_backend(16, torch.float32, torch.float32,
-                                           16, False)
-            assert (backend.get_name() == "TORCH_SDPA_VLLM_V1"
-                    if use_v1 else "TORCH_SDPA")
-
-        elif device == "cuda":
-            with patch("vllm.attention.selector.current_platform",
-                       CudaPlatform()):
-                backend = get_attn_backend(16, torch.float32, torch.float32,
-                                           16, False)
-            assert (backend.get_name() == "FLEX_ATTENTION"
-                    if use_v1 else "XFORMERS")
-
-
 def test_flash_attn(monkeypatch: pytest.MonkeyPatch):
     """Test FlashAttn validation."""
     # TODO: When testing for v1, pipe in `use_v1` as an argument to
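
The deleted test above came in with the change that is being reverted, so it goes out with it. For the dtypes V1 still accepts, backend selection can be exercised the same way; below is a minimal sketch, assuming the imports this test module already relies on (get_attn_backend, CudaPlatform, patch) resolve from the paths shown:

import torch
from unittest.mock import patch

from vllm.attention.selector import get_attn_backend
from vllm.platforms.cuda import CudaPlatform

# fp16 still flows through the normal V1 CUDA selection path after the revert;
# only fp32 (and other non-fp16/bf16 dtypes) is turned away, and that now
# happens in the engine-args oracle rather than in backend selection.
with patch("vllm.attention.selector.current_platform", CudaPlatform()):
    backend = get_attn_backend(16, torch.float16, torch.float16, 16, False)
print(backend.get_name())  # exact name depends on GPU and installed kernels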

tests/v1/worker/test_gpu_model_runner.py

Lines changed: 0 additions & 6 deletions
@@ -4,7 +4,6 @@
 import random

 import pytest
-import torch

 from vllm.attention import Attention
 from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
@@ -400,7 +399,6 @@ def test_load_model_weights_inplace(dist_init, model_runner, model_runner_2):


 def test_init_kv_cache_with_kv_sharing_invalid_target_layer_order():
-    torch.set_default_dtype(torch.float16)
     layer_0 = "model.layers.0.self_attn.attn"
     layer_1 = "model.layers.1.self_attn.attn"
     error_msg = f"{layer_1} must come before the current layer"
@@ -429,7 +427,6 @@ def test_init_kv_cache_with_kv_sharing_invalid_target_layer_order():


 def test_init_kv_cache_with_kv_sharing_target_layer_not_exist():
-    torch.set_default_dtype(torch.float16)
     layer_0 = "model.layers.0.self_attn.attn"
     layer_1 = "model.layers.1.self_attn.attn"
     invalid_layer = "model.layers.0.cross_attn.attn"
@@ -458,7 +455,6 @@ def test_init_kv_cache_with_kv_sharing_target_layer_not_exist():


 def test_init_kv_cache_with_kv_sharing_target_same_as_current():
-    torch.set_default_dtype(torch.float16)
     layer_0 = "model.layers.0.self_attn.attn"
     layer_1 = "model.layers.1.self_attn.attn"
     error_msg = f"{layer_1} cannot be the same as the current layer"
@@ -487,7 +483,6 @@ def test_init_kv_cache_with_kv_sharing_target_same_as_current():


 def test_init_kv_cache_without_kv_sharing():
-    torch.set_default_dtype(torch.float16)
     layer_0 = "model.layers.0.self_attn.attn"
     layer_1 = "model.layers.1.self_attn.attn"
     vllm_config = get_vllm_config()
@@ -555,7 +550,6 @@ def test_init_kv_cache_without_kv_sharing():


 def test_init_kv_cache_with_kv_sharing_valid():
-    torch.set_default_dtype(torch.float16)
     layer_0 = "model.layers.0.self_attn.attn"
     layer_1 = "model.layers.1.self_attn.attn"
     vllm_config = get_vllm_config()

vllm/engine/arg_utils.py

Lines changed: 7 additions & 0 deletions
@@ -1337,6 +1337,13 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
                                recommend_to_remove=False)
             return False

+        # Only Fp16 and Bf16 dtypes since we only support FA.
+        V1_SUPPORTED_DTYPES = [torch.bfloat16, torch.float16]
+        if model_config.dtype not in V1_SUPPORTED_DTYPES:
+            _raise_or_fallback(feature_name=f"--dtype {model_config.dtype}",
+                               recommend_to_remove=False)
+            return False
+
         # No Embedding Models so far.
         if model_config.task not in ["generate"]:
             _raise_or_fallback(feature_name=f"--task {model_config.task}",
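
This is the heart of the revert: the V1 oracle once again refuses any model dtype other than bfloat16 and float16, instead of letting fp32 through to a FlexAttention fallback. A minimal sketch of the restored gate in isolation follows; the _raise_or_fallback stand-in is a hypothetical placeholder for the real helper in this file:

import torch

V1_SUPPORTED_DTYPES = [torch.bfloat16, torch.float16]

def _raise_or_fallback(feature_name: str, recommend_to_remove: bool) -> None:
    # Hypothetical stand-in: the real helper either raises (when VLLM_USE_V1=1
    # was set explicitly) or logs and falls back to the V0 engine.
    print(f"{feature_name} is not supported by the V1 engine.")

def dtype_ok_for_v1(dtype: torch.dtype) -> bool:
    # Mirrors the restored check: only fp16/bf16 pass; everything else,
    # fp32 included, is rejected before backend selection is ever reached.
    if dtype not in V1_SUPPORTED_DTYPES:
        _raise_or_fallback(feature_name=f"--dtype {dtype}",
                           recommend_to_remove=False)
        return False
    return True

assert dtype_ok_for_v1(torch.bfloat16)
assert not dtype_ok_for_v1(torch.float32)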

vllm/platforms/cuda.py

Lines changed: 0 additions & 4 deletions
@@ -233,10 +233,6 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
             logger.info_once("Using Triton backend on V1 engine.")
             return ("vllm.v1.attention.backends."
                     "triton_attn.TritonAttentionBackend")
-        if dtype not in (torch.float16, torch.bfloat16):
-            logger.info_once(
-                f"Using FlexAttenion backend for {dtype} on V1 engine.")
-            return "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend"  # noqa: E501
         if cls.is_device_capability(100):
             # Prefer FlashInfer for V1 on Blackwell GPUs if installed
             try:
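
With the FlexAttention branch gone, CUDA backend selection on V1 no longer special-cases non-fp16/bf16 dtypes at all; those dtypes are stopped earlier by the check restored in vllm/engine/arg_utils.py above. An illustrative sketch of the post-revert ordering, with short labels standing in for the real backend class paths:

import torch

def select_v1_cuda_backend(dtype: torch.dtype, is_blackwell: bool) -> str:
    """Illustrative only; the real selector in vllm/platforms/cuda.py also
    honors explicit backend overrides, head sizes, and installed kernels."""
    if dtype not in (torch.float16, torch.bfloat16):
        # No FlexAttention escape hatch anymore: the V1 oracle rejects these
        # dtypes before CUDA backend selection is ever consulted.
        raise ValueError(f"{dtype} is not supported on the V1 engine")
    # Per the surrounding context lines: FlashInfer is preferred on Blackwell
    # (device capability 100) when installed, otherwise FlashAttention.
    return "FLASHINFER" if is_blackwell else "FLASH_ATTN"

print(select_v1_cuda_backend(torch.bfloat16, is_blackwell=False))  # FLASH_ATTN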
