
Commit 4b9d13d

Quick fix: add a conditional import for flash_attn_varlen_func in flash_attn (#20143)

Signed-off-by: Chendi.Xue <chendi.xue@intel.com>

1 parent 7719456 · commit 4b9d13d

2 files changed (+11, -3 lines)


vllm/attention/utils/fa_utils.py

Lines changed: 4 additions, 0 deletions

@@ -66,3 +66,7 @@ def get_flash_attn_version(requires_alibi: bool = False) -> Optional[int]:
 def flash_attn_supports_fp8() -> bool:
     return get_flash_attn_version() == 3 and \
         current_platform.get_device_capability().major == 9
+
+
+def is_flash_attn_varlen_func_available() -> bool:
+    return current_platform.is_cuda() or current_platform.is_xpu()
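
For context, a minimal sketch of how a platform probe like this new helper is consumed. The _FakePlatform stub below is an illustrative stand-in for vllm.platforms.current_platform (an assumption, so the sketch runs without vLLM installed); in vLLM itself the helper simply delegates to the real platform object as shown in the diff above.

# Illustrative stand-in for vllm.platforms.current_platform (assumption:
# a CPU-only environment without vLLM installed).
class _FakePlatform:
    def is_cuda(self) -> bool:
        return False

    def is_xpu(self) -> bool:
        return False

current_platform = _FakePlatform()

def is_flash_attn_varlen_func_available() -> bool:
    # Mirrors the new helper: the varlen FlashAttention entry points are only
    # present on CUDA and XPU builds.
    return current_platform.is_cuda() or current_platform.is_xpu()

print(is_flash_attn_varlen_func_available())  # False on this fake CPU platform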

vllm/v1/attention/backends/flash_attn.py

Lines changed: 7 additions, 3 deletions

@@ -14,10 +14,14 @@
 from vllm.attention.layer import Attention
 from vllm.attention.ops.merge_attn_states import merge_attn_states
 from vllm.attention.utils.fa_utils import (flash_attn_supports_fp8,
-                                           flash_attn_varlen_func,
                                            get_flash_attn_version,
-                                           get_scheduler_metadata,
-                                           reshape_and_cache_flash)
+                                           is_flash_attn_varlen_func_available)
+
+if is_flash_attn_varlen_func_available():
+    from vllm.attention.utils.fa_utils import (flash_attn_varlen_func,
+                                               get_scheduler_metadata,
+                                               reshape_and_cache_flash)
+
 from vllm.config import VllmConfig, get_layers_from_vllm_config
 from vllm.logger import init_logger
 from vllm.utils import cdiv
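
The guarded import itself can be exercised in isolation. The sketch below hard-codes the availability check to False (an assumption, so it runs on any machine without vLLM or FlashAttention) and adds an illustrative else branch that is not part of this commit, to show that module import no longer fails on platforms without the varlen kernels.

# Self-contained sketch of the conditional-import pattern from this commit.
def is_flash_attn_varlen_func_available() -> bool:
    return False  # pretend we are on a platform without the varlen kernels

if is_flash_attn_varlen_func_available():
    # Only reached on CUDA/XPU builds, where these names actually exist.
    from vllm.attention.utils.fa_utils import (flash_attn_varlen_func,
                                               get_scheduler_metadata,
                                               reshape_and_cache_flash)
else:
    # Illustrative fallback (not in the commit): keep the names defined so an
    # accidental use fails with an explicit None rather than a NameError.
    flash_attn_varlen_func = None
    get_scheduler_metadata = None
    reshape_and_cache_flash = None

print(flash_attn_varlen_func)  # None on platforms without varlen support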
