Skip to content

Commit 5f063a8

Browse files
authored
[bugfix] add supports_v1 platform interface (#15417)
Signed-off-by: Joe Runde <Joseph.Runde@ibm.com>
1 parent 5d8e1c9 commit 5f063a8

File tree

5 files changed

+31
-7
lines changed

5 files changed

+31
-7
lines changed

vllm/engine/arg_utils.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1666,9 +1666,8 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
16661666
_raise_or_fallback(feature_name=name, recommend_to_remove=True)
16671667
return False
16681668

1669-
# No support for device type other than CUDA, AMD (experimental) or
1670-
# TPU (experimental) so far.
1671-
if not (current_platform.is_cuda_alike() or current_platform.is_tpu()):
1669+
# Platforms must decide if they can support v1 for this model
1670+
if not current_platform.supports_v1(model_config=model_config):
16721671
_raise_or_fallback(
16731672
feature_name=f"device type={current_platform.device_type}",
16741673
recommend_to_remove=False)

vllm/platforms/cuda.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,9 @@
2020
from .interface import DeviceCapability, Platform, PlatformEnum, _Backend
2121

2222
if TYPE_CHECKING:
23-
from vllm.config import VllmConfig
23+
from vllm.config import ModelConfig, VllmConfig
2424
else:
25+
ModelConfig = None
2526
VllmConfig = None
2627

2728
logger = init_logger(__name__)
@@ -303,6 +304,10 @@ def get_device_communicator_cls(cls) -> str:
303304
def supports_fp8(cls) -> bool:
304305
return cls.has_device_capability(89)
305306

307+
@classmethod
308+
def supports_v1(cls, model_config: ModelConfig) -> bool:
309+
return True
310+
306311

307312
# NVML utils
308313
# Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`,

vllm/platforms/interface.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,10 @@
1212
from vllm.logger import init_logger
1313

1414
if TYPE_CHECKING:
15-
from vllm.config import VllmConfig
15+
from vllm.config import ModelConfig, VllmConfig
1616
from vllm.utils import FlexibleArgumentParser
1717
else:
18+
ModelConfig = None
1819
VllmConfig = None
1920
FlexibleArgumentParser = None
2021

@@ -371,6 +372,13 @@ def use_all_gather(cls) -> bool:
371372
or parallel_config.distributed_executor_backend
372373
== "external_launcher")
373374

375+
@classmethod
376+
def supports_v1(cls, model_config: ModelConfig) -> bool:
377+
"""Returns whether the current platform can support v1 for the supplied
378+
model configuration.
379+
"""
380+
return False
381+
374382

375383
class UnspecifiedPlatform(Platform):
376384
_enum = PlatformEnum.UNSPECIFIED

vllm/platforms/rocm.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,9 @@
1212
from .interface import DeviceCapability, Platform, PlatformEnum, _Backend
1313

1414
if TYPE_CHECKING:
15-
from vllm.config import VllmConfig
15+
from vllm.config import ModelConfig, VllmConfig
1616
else:
17+
ModelConfig = None
1718
VllmConfig = None
1819

1920
logger = init_logger(__name__)
@@ -249,3 +250,8 @@ def fp8_dtype(cls) -> torch.dtype:
249250
return torch.float8_e4m3fnuz
250251
else:
251252
return torch.float8_e4m3fn
253+
254+
@classmethod
255+
def supports_v1(cls, model_config: ModelConfig) -> bool:
256+
# V1 support on AMD gpus is experimental
257+
return True

vllm/platforms/tpu.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,9 @@
1010
from .interface import Platform, PlatformEnum, _Backend
1111

1212
if TYPE_CHECKING:
13-
from vllm.config import VllmConfig
13+
from vllm.config import ModelConfig, VllmConfig
1414
else:
15+
ModelConfig = None
1516
VllmConfig = None
1617

1718
logger = init_logger(__name__)
@@ -127,3 +128,8 @@ def get_device_communicator_cls(cls) -> str:
127128
@classmethod
128129
def use_all_gather(cls) -> bool:
129130
return True
131+
132+
@classmethod
133+
def supports_v1(cls, model_config: ModelConfig) -> bool:
134+
# V1 support on TPU is experimental
135+
return True

0 commit comments

Comments
 (0)