
Commit 54f60ba

wangxiyuan authored and minpeter committed
[Platform] Move platform check to right place (vllm-project#18470)
Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
Signed-off-by: minpeter <kali2005611@gmail.com>
1 parent 661c031 commit 54f60ba

7 files changed: +71, -25 lines

vllm/config.py

Lines changed: 10 additions & 25 deletions

@@ -42,7 +42,10 @@
     try_get_generation_config, uses_mrope)
 from vllm.transformers_utils.s3_utils import S3Model
 from vllm.transformers_utils.utils import is_s3, maybe_model_redirect
-from vllm.utils import (GiB_bytes, LayerBlockType, cuda_device_count_stateless,
+from vllm.utils import (DEFAULT_MAX_NUM_BATCHED_TOKENS,
+                        MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS,
+                        POOLING_MODEL_MAX_NUM_BATCHED_TOKENS, GiB_bytes,
+                        LayerBlockType, cuda_device_count_stateless,
                         get_cpu_memory, get_open_port, is_torch_equal_or_newer,
                         random_uuid, resolve_obj_by_qualname)

@@ -64,12 +67,6 @@

 ConfigT = TypeVar("ConfigT", bound=ConfigType)

-# This value is chosen to have a balance between ITL and TTFT. Note it is
-# not optimized for throughput.
-_DEFAULT_MAX_NUM_BATCHED_TOKENS = 2048
-_POOLING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768
-_MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS = 5120
-
 TaskOption = Literal["auto", "generate", "embedding", "embed", "classify",
                      "score", "reward", "transcription"]

@@ -2074,28 +2071,28 @@ def __post_init__(self) -> None:
                     # so we don't reject sequences on account of a short
                     # max_num_batched_tokens.
                     self.max_num_batched_tokens = max(
-                        self.max_model_len, _DEFAULT_MAX_NUM_BATCHED_TOKENS)
+                        self.max_model_len, DEFAULT_MAX_NUM_BATCHED_TOKENS)
                 else:
                     self.max_num_batched_tokens = (
-                        _DEFAULT_MAX_NUM_BATCHED_TOKENS)
+                        DEFAULT_MAX_NUM_BATCHED_TOKENS)
             else:
                 # If max_model_len is too short, use
-                # _DEFAULT_MAX_NUM_BATCHED_TOKENS as the default value
+                # DEFAULT_MAX_NUM_BATCHED_TOKENS as the default value
                 # for higher throughput.
                 self.max_num_batched_tokens = max(
-                    self.max_model_len, _DEFAULT_MAX_NUM_BATCHED_TOKENS)
+                    self.max_model_len, DEFAULT_MAX_NUM_BATCHED_TOKENS)

             if self.runner_type == "pooling":
                 # Choose specific value for higher throughput
                 self.max_num_batched_tokens = max(
                     self.max_num_batched_tokens,
-                    _POOLING_MODEL_MAX_NUM_BATCHED_TOKENS,
+                    POOLING_MODEL_MAX_NUM_BATCHED_TOKENS,
                 )
             if self.is_multimodal_model:
                 # The value needs to be at least the number of multimodal tokens
                 self.max_num_batched_tokens = max(
                     self.max_num_batched_tokens,
-                    _MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS,
+                    MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS,
                 )

             # When using default settings,

@@ -4316,18 +4313,6 @@ def __post_init__(self):
                 "full_cuda_graph is not supported with "
                 "cascade attention. Disabling cascade attention.")
             self.model_config.disable_cascade_attn = True
-
-        if self.model_config and self.model_config.use_mla and \
-            not (current_platform.is_cuda() or current_platform.is_rocm()):
-            logger.info(
-                "MLA is enabled on a non-GPU platform; forcing chunked "
-                "prefill and prefix caching to be disabled.")
-            self.scheduler_config.enable_chunked_prefill = False
-            self.scheduler_config.chunked_prefill_enabled = False
-            self.scheduler_config.max_num_batched_tokens = max(
-                self.scheduler_config.max_model_len,
-                _DEFAULT_MAX_NUM_BATCHED_TOKENS)
-
             if self.cache_config is not None:
                 self.cache_config.enable_prefix_caching = False

vllm/platforms/cpu.py

Lines changed: 11 additions & 0 deletions

@@ -9,6 +9,7 @@
 import torch

 from vllm.logger import init_logger
+from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS

 from .interface import CpuArchEnum, Platform, PlatformEnum, _Backend

@@ -177,6 +178,16 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
                 " set VLLM_WORKER_MULTIPROC_METHOD to fork explicitly.")
             os.environ['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn'

+        if vllm_config.model_config and vllm_config.model_config.use_mla:
+            logger.info(
+                "MLA is enabled on a non-GPU platform; forcing chunked "
+                "prefill and prefix caching to be disabled.")
+            vllm_config.scheduler_config.enable_chunked_prefill = False
+            vllm_config.scheduler_config.chunked_prefill_enabled = False
+            vllm_config.scheduler_config.max_num_batched_tokens = max(
+                vllm_config.scheduler_config.max_model_len,
+                DEFAULT_MAX_NUM_BATCHED_TOKENS)
+
     @classmethod
     def is_pin_memory_available(cls) -> bool:
         logger.warning("Pin memory is not supported on CPU.")

vllm/platforms/hpu.py

Lines changed: 11 additions & 0 deletions

@@ -7,6 +7,7 @@

 from vllm import envs
 from vllm.logger import init_logger
+from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS

 from .interface import Platform, PlatformEnum, _Backend

@@ -80,6 +81,16 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
                 "VLLM_WORKER_MULTIPROC_METHOD=fork explicitly.")
             os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"

+        if vllm_config.model_config and vllm_config.model_config.use_mla:
+            logger.info(
+                "MLA is enabled on a non-GPU platform; forcing chunked "
+                "prefill and prefix caching to be disabled.")
+            vllm_config.scheduler_config.enable_chunked_prefill = False
+            vllm_config.scheduler_config.chunked_prefill_enabled = False
+            vllm_config.scheduler_config.max_num_batched_tokens = max(
+                vllm_config.scheduler_config.max_model_len,
+                DEFAULT_MAX_NUM_BATCHED_TOKENS)
+
     @classmethod
     def is_pin_memory_available(cls):
         logger.warning("Pin memory is not supported on HPU.")

vllm/platforms/neuron.py

Lines changed: 11 additions & 0 deletions

@@ -6,6 +6,7 @@

 from vllm import envs
 from vllm.logger import init_logger
+from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS

 from .interface import Platform, PlatformEnum

@@ -56,6 +57,16 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
             vllm_config.cache_config.block_size = \
                 vllm_config.model_config.max_model_len  # type: ignore

+        if vllm_config.model_config and vllm_config.model_config.use_mla:
+            logger.info(
+                "MLA is enabled on a non-GPU platform; forcing chunked "
+                "prefill and prefix caching to be disabled.")
+            vllm_config.scheduler_config.enable_chunked_prefill = False
+            vllm_config.scheduler_config.chunked_prefill_enabled = False
+            vllm_config.scheduler_config.max_num_batched_tokens = max(
+                vllm_config.scheduler_config.max_model_len,
+                DEFAULT_MAX_NUM_BATCHED_TOKENS)
+
     @classmethod
     def is_pin_memory_available(cls) -> bool:
         logger.warning("Pin memory is not supported on Neuron.")

vllm/platforms/tpu.py

Lines changed: 11 additions & 0 deletions

@@ -9,6 +9,7 @@
 from vllm.inputs import ProcessorInputs, PromptType
 from vllm.logger import init_logger
 from vllm.sampling_params import SamplingParams, SamplingType
+from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS

 from .interface import Platform, PlatformEnum, _Backend

@@ -161,6 +162,16 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
                 "Forcing --disable_chunked_mm_input.")
             scheduler_config.disable_chunked_mm_input = True

+        if vllm_config.model_config and vllm_config.model_config.use_mla:
+            logger.info(
+                "MLA is enabled on a non-GPU platform; forcing chunked "
+                "prefill and prefix caching to be disabled.")
+            vllm_config.scheduler_config.enable_chunked_prefill = False
+            vllm_config.scheduler_config.chunked_prefill_enabled = False
+            vllm_config.scheduler_config.max_num_batched_tokens = max(
+                vllm_config.scheduler_config.max_model_len,
+                DEFAULT_MAX_NUM_BATCHED_TOKENS)
+
     @classmethod
     def is_pin_memory_available(cls):
         logger.warning("Pin memory is not supported on TPU.")

vllm/platforms/xpu.py

Lines changed: 11 additions & 0 deletions

@@ -5,6 +5,7 @@
 import torch

 from vllm.logger import init_logger
+from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS

 from .interface import DeviceCapability, Platform, PlatformEnum, _Backend

@@ -113,6 +114,16 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
                 parallel_config.distributed_executor_backend)
             parallel_config.distributed_executor_backend = "ray"

+        if vllm_config.model_config and vllm_config.model_config.use_mla:
+            logger.info(
+                "MLA is enabled on a non-GPU platform; forcing chunked "
+                "prefill and prefix caching to be disabled.")
+            vllm_config.scheduler_config.enable_chunked_prefill = False
+            vllm_config.scheduler_config.chunked_prefill_enabled = False
+            vllm_config.scheduler_config.max_num_batched_tokens = max(
+                vllm_config.scheduler_config.max_model_len,
+                DEFAULT_MAX_NUM_BATCHED_TOKENS)
+
     @classmethod
     def is_pin_memory_available(cls):
         logger.warning("Pin memory is not supported on XPU.")

vllm/utils.py

Lines changed: 6 additions & 0 deletions

@@ -77,6 +77,12 @@

 logger = init_logger(__name__)

+# This value is chosen to have a balance between ITL and TTFT. Note it is
+# not optimized for throughput.
+DEFAULT_MAX_NUM_BATCHED_TOKENS = 2048
+POOLING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768
+MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS = 5120
+
 # Exception strings for non-implemented encoder/decoder scenarios

 # Reminder: Please update docs/source/features/compatibility_matrix.md
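With the three limits now exposed as public names in vllm.utils rather than private constants in vllm/config.py, the platform hooks above (and any other module) can import them directly. A minimal sanity check, assuming a vllm checkout containing this commit is importable:

from vllm.utils import (DEFAULT_MAX_NUM_BATCHED_TOKENS,
                        MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS,
                        POOLING_MODEL_MAX_NUM_BATCHED_TOKENS)

# The values are unchanged by the move; only the leading underscore was
# dropped and the defining module changed from vllm/config.py to vllm/utils.py.
assert DEFAULT_MAX_NUM_BATCHED_TOKENS == 2048
assert POOLING_MODEL_MAX_NUM_BATCHED_TOKENS == 32768
assert MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS == 5120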
