
Commit 254b1e9

hax0r31337 authored and WorldExplored committed
[Bugfix] Voxtral on Blackwell GPUs (RTX 50 series) (#21077)
Signed-off-by: hax0r31337 <liulihaocaiqwq@gmail.com>
1 parent d7eeda6 commit 254b1e9

1 file changed: +33 -0 lines changed

vllm/attention/layer.py

Lines changed: 33 additions & 0 deletions
@@ -16,13 +16,42 @@
                                           has_kv_transfer_group,
                                           is_v1_kv_transfer_group)
 from vllm.forward_context import ForwardContext, get_forward_context
+from vllm.logger import init_logger
 from vllm.model_executor.layers.linear import UnquantizedLinearMethod
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
 from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
 from vllm.platforms import _Backend, current_platform
 from vllm.utils import direct_register_custom_op
 
+logger = init_logger(__name__)
+USE_XFORMERS_OPS = None
+
+
+def check_xformers_availability():
+    global USE_XFORMERS_OPS
+    if USE_XFORMERS_OPS is not None:
+        return USE_XFORMERS_OPS
+
+    if current_platform.is_cuda() and current_platform.has_device_capability(
+            100):
+        # Xformers FA is not compatible with B200
+        USE_XFORMERS_OPS = False
+    else:
+        try:
+            from importlib.util import find_spec
+
+            find_spec("xformers.ops")
+            USE_XFORMERS_OPS = True
+        except ImportError:
+            USE_XFORMERS_OPS = False
+
+    # the warning only needs to be shown once
+    if not USE_XFORMERS_OPS:
+        logger.warning("Xformers is not available, falling back.")
+
+    return USE_XFORMERS_OPS
+
 
 class Attention(nn.Module):
     """Attention layer.
@@ -314,6 +343,10 @@ def __init__(
             _Backend.TORCH_SDPA, _Backend.XFORMERS, _Backend.PALLAS_VLLM_V1
         } else _Backend.TORCH_SDPA
 
+        if (self.attn_backend == _Backend.XFORMERS
+                and not check_xformers_availability()):
+            self.attn_backend = _Backend.TORCH_SDPA
+
     def forward(
         self,
         query: torch.Tensor,
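
Because USE_XFORMERS_OPS is cached at module level, the probe and its warning run at most once per process even though every layer that passes through this constructor calls the check. As a design note, a hedged sketch of equivalent once-only semantics using functools.cache (an alternative, not what the commit uses):

# Design-note sketch (assumption: an alternative, not the commit's code):
# the same once-per-process semantics via functools.cache instead of a
# module-level global.
import functools
from importlib.util import find_spec

@functools.cache
def xformers_available() -> bool:
    # find_spec probes importability without importing the package;
    # it raises ModuleNotFoundError when "xformers" itself is absent.
    try:
        return find_spec("xformers.ops") is not None
    except ImportError:
        return False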
