Skip to content

Commit 61e2082

Browse files
authored
Fall back if flashinfer comm module not found (#20936)
Signed-off-by: Yong Hoon Shin <yhshin@meta.com>
1 parent 55e1c66 commit 61e2082

File tree

1 file changed

+8
-5
lines changed

1 file changed

+8
-5
lines changed

vllm/compilation/collective_fusion.py

Lines changed: 8 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -20,10 +20,12 @@
2020
from .vllm_inductor_pass import VllmInductorPass
2121

2222
if find_spec("flashinfer"):
23-
import flashinfer.comm as flashinfer_comm
24-
25-
flashinfer_comm = (flashinfer_comm if hasattr(
26-
flashinfer_comm, "trtllm_allreduce_fusion") else None)
23+
try:
24+
import flashinfer.comm as flashinfer_comm
25+
flashinfer_comm = (flashinfer_comm if hasattr(
26+
flashinfer_comm, "trtllm_allreduce_fusion") else None)
27+
except ImportError:
28+
flashinfer_comm = None
2729
else:
2830
flashinfer_comm = None
2931
from vllm.platforms import current_platform
@@ -411,7 +413,8 @@ def __init__(self, config: VllmConfig, max_token_num: int):
411413
use_fp32_lamport = self.model_dtype == torch.float32
412414
if flashinfer_comm is None:
413415
logger.warning(
414-
"Flashinfer is not installed, skipping allreduce fusion pass")
416+
"Flashinfer is not installed or comm module not found, "
417+
"skipping allreduce fusion pass")
415418
return
416419
# Check if the world size is supported
417420
if self.tp_size not in _FI_MAX_SIZES:

0 commit comments

Comments (0)